diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 36fb0f91e4c..b2d6a43731f 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -45,6 +45,10 @@ namespace c10 { _(prim, CudaFusionGuard) \ _(prim, FunctionalGraph) \ _(prim, add_optional) \ + _(prim, view_copy) \ + _(prim, reshape_copy) \ + _(prim, squeeze_copy) \ + _(prim, unsqueeze_copy) \ _(prim, DifferentiableGraph) \ _(prim, TensorExprGroup) \ _(prim, TensorExprDynamicGroup) \ diff --git a/benchmarks/cpp/nvfuser/batch_norm.cpp b/benchmarks/cpp/nvfuser/batch_norm.cpp index ef6bdd667d6..57e889b19fb 100644 --- a/benchmarks/cpp/nvfuser/batch_norm.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -44,8 +45,8 @@ static void setupBatchNorm(Fusion* fusion, DataType dtype) { bias = castOp(DataType::Float, bias); } - auto momentum_ptr = new Double(kMomentum); - auto eps_ptr = new Double(kEps); + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); auto result = batch_norm( input, diff --git a/benchmarks/cpp/nvfuser/batch_norm_backward.cpp b/benchmarks/cpp/nvfuser/batch_norm_backward.cpp index e4a9fdcb034..77a09564de5 100644 --- a/benchmarks/cpp/nvfuser/batch_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm_backward.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -49,7 +50,7 @@ static void setupBatchNorm_BWD(Fusion* fusion, DataType dtype) { grad_output = castOp(DataType::Float, grad_output); } - auto eps_ptr = new Double(kEps); + auto eps_ptr = IrBuilder::create(kEps); auto result = batch_norm_backward( input, diff --git a/benchmarks/cpp/nvfuser/bert.cpp b/benchmarks/cpp/nvfuser/bert.cpp index f8a389331ee..a1dd58d5646 100644 --- a/benchmarks/cpp/nvfuser/bert.cpp +++ b/benchmarks/cpp/nvfuser/bert.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -36,7 +37,7 @@ static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) { fusion->addInput(tv1); // TODO: should be input - auto d16 = new Double(1.0); + auto d16 = IrBuilder::create(1.0); if (is_fp16) { tv0 = castOp(DataType::Float, tv0); @@ -47,7 +48,7 @@ static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) { auto tv3 = add(tv2, tv0); auto tv10 = softmax(tv3, 3); - auto dropout_tvs = dropout(tv10, new Double(0.9)); + auto dropout_tvs = dropout(tv10, IrBuilder::create(0.9)); auto tv12 = dropout_tvs.mask; auto tv14 = dropout_tvs.output; @@ -83,9 +84,9 @@ static void setupDivMaxSoftmaxDropoutBackward(Fusion* fusion, DataType dtype) { } // TODO: should be inputs - auto d32 = new Double(1.0); + auto d32 = IrBuilder::create(1.0); // fusion->addInput(d32); - auto d33 = new Double(2.0); + auto d33 = IrBuilder::create(2.0); // fusion->addInput(d33); auto tv4 = mul(tv2, tv3); @@ -252,14 +253,15 @@ static void setupBiasDropoutAddLayernormFwd(Fusion* fusion, DataType dtype) { auto tv5 = broadcast(tv4, {true, true, false}); auto tv6 = add(tv3, tv5); - auto dropout_outs = dropout(tv6, new Double(0.9)); + auto dropout_outs = dropout(tv6, IrBuilder::create(0.9)); auto tv8 = dropout_outs.output; auto tv10 = dropout_outs.mask; auto tv11 = add(tv10, tv2); - auto layer_norm_outs = layer_norm(tv11, 1, tv0, tv1, new Double(1e-5)); + auto layer_norm_outs = + layer_norm(tv11, 1, tv0, tv1, IrBuilder::create(1e-5)); auto tv14 = layer_norm_outs.output; auto tv21 = layer_norm_outs.mean; auto tv26 = 
layer_norm_outs.invstd; @@ -481,7 +483,7 @@ static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) { tv1 = castOp(DataType::Float, tv1); tv8 = castOp(DataType::Float, tv8); } - auto d36 = mul(new Double(1.0), tv1->axis(2)->extent()); + auto d36 = mul(IrBuilder::create(1.0), tv1->axis(2)->extent()); auto d47 = unaryOp(UnaryOpType::Reciprocal, d36); auto tv9 = broadcast(tv5, {true, true, false}); @@ -583,7 +585,7 @@ static void setupBiasDropoutAddLayernormBwd3(Fusion* fusion, DataType dtype) { } // Uncertain this is the right value, but going for it anyways - auto d34 = div(new Double(1.0), tv0->axis(2)->extent()); + auto d34 = div(IrBuilder::create(1.0), tv0->axis(2)->extent()); auto tv25 = mul(tv21, tv0); auto tv26 = mul(tv25, d34); diff --git a/benchmarks/cpp/nvfuser/gelu_backward.cpp b/benchmarks/cpp/nvfuser/gelu_backward.cpp index 9d53d9c2759..f1811795462 100644 --- a/benchmarks/cpp/nvfuser/gelu_backward.cpp +++ b/benchmarks/cpp/nvfuser/gelu_backward.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -41,23 +42,23 @@ static void setupFusion(Fusion* fusion) { auto t5 = castOp(DataType::Float, t4); auto t6 = broadcast(t3, {true, true, false}); auto t7 = add(t6, t5); - auto t8 = mul(t7, new Double(k_079)); - auto t9 = mul(t7, new Double(k_004)); + auto t8 = mul(t7, IrBuilder::create(k_079)); + auto t9 = mul(t7, IrBuilder::create(k_004)); auto t10 = mul(t9, t7); - auto t11 = add(t10, new Int(1)); + auto t11 = add(t10, IrBuilder::create(1)); auto t12 = mul(t8, t11); auto t13 = unaryOp(UnaryOpType::Tanh, t12); - auto t14 = mul(t7, new Double(0.5)); + auto t14 = mul(t7, IrBuilder::create(0.5)); auto t15 = mul(t13, t13); auto t16 = unaryOp(UnaryOpType::Neg, t15); - auto t17 = add(t16, new Int(1)); - auto t18 = mul(t7, new Double(k_010)); + auto t17 = add(t16, IrBuilder::create(1)); + auto t18 = mul(t7, IrBuilder::create(k_010)); auto t19 = mul(t18, t7); - auto t20 = add(t19, new Double(k_079)); + auto t20 = add(t19, IrBuilder::create(k_079)); auto t21 = mul(t17, t20); auto t22 = mul(t14, t21); - auto t23 = add(t13, new Int(1)); - auto t24 = mul(t23, new Double(0.5)); + auto t23 = add(t13, IrBuilder::create(1)); + auto t24 = mul(t23, IrBuilder::create(0.5)); auto t25 = add(t22, t24); auto t26 = mul(t25, t1); diff --git a/benchmarks/cpp/nvfuser/heuristic_cache.cpp b/benchmarks/cpp/nvfuser/heuristic_cache.cpp index 22b8ec4ce97..65f850a016c 100644 --- a/benchmarks/cpp/nvfuser/heuristic_cache.cpp +++ b/benchmarks/cpp/nvfuser/heuristic_cache.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -129,7 +130,7 @@ static auto getLayerForwardNormRuntime( Fusion& fusion = *fusion_ptr.get(); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); auto input = makeSymbolicTensor(shape.size()); fusion.addInput(input); diff --git a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp index 22b8ec4ce97..65f850a016c 100644 --- a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp +++ b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -129,7 +130,7 @@ static auto getLayerForwardNormRuntime( Fusion& fusion = *fusion_ptr.get(); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); auto input = makeSymbolicTensor(shape.size()); fusion.addInput(input); diff --git a/benchmarks/cpp/nvfuser/instance_norm.cpp 
b/benchmarks/cpp/nvfuser/instance_norm.cpp index 395ac6c8c9c..007291d75f5 100644 --- a/benchmarks/cpp/nvfuser/instance_norm.cpp +++ b/benchmarks/cpp/nvfuser/instance_norm.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -39,8 +40,8 @@ static void setupInstanceNorm(Fusion* fusion, DataType dtype) { const bool kTraining = true; const float kMomentum = 0.1; const float kEps = 1e-5; - auto momentum_ptr = new Double(kMomentum); - auto eps_ptr = new Double(kEps); + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); auto norm = instance_norm( input, diff --git a/benchmarks/cpp/nvfuser/layer_norm.cpp b/benchmarks/cpp/nvfuser/layer_norm.cpp index c4f79b2b668..7500ac8525b 100644 --- a/benchmarks/cpp/nvfuser/layer_norm.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -24,7 +25,7 @@ static void setupLayerNorm(Fusion* fusion, DataType dtype) { const int kReductionAxis = 1; const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); // setup fusion auto input = makeContigTensor(2, dtype); diff --git a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp index 43eafcc42fb..045465e7125 100644 --- a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -22,7 +23,7 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); const int kReductionAxis = 1; - Double* eps_ptr = new Double(1e-5); + Double* eps_ptr = IrBuilder::create(1e-5); // setup fusion auto grad_out = makeContigTensor(2, dtype); diff --git a/benchmarks/cpp/nvfuser/shape_inference.cpp b/benchmarks/cpp/nvfuser/shape_inference.cpp index 33a9404b073..15acc51bb37 100644 --- a/benchmarks/cpp/nvfuser/shape_inference.cpp +++ b/benchmarks/cpp/nvfuser/shape_inference.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -151,7 +152,7 @@ static auto getLayerForwardNormRuntime( Fusion& fusion = *fusion_ptr.get(); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); auto input = makeSymbolicTensor(shape.size()); fusion.addInput(input); diff --git a/benchmarks/cpp/nvfuser/softmax_dropout.cpp b/benchmarks/cpp/nvfuser/softmax_dropout.cpp index b4890eaf8d8..828940933f4 100644 --- a/benchmarks/cpp/nvfuser/softmax_dropout.cpp +++ b/benchmarks/cpp/nvfuser/softmax_dropout.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -35,7 +36,7 @@ static void setupSoftmaxDropout( auto attention_scores = makeContigTensor(4, dtype); auto attention_mask = makeContigTensor(4, dtype); - Double* divisor = new Double(); + Double* divisor = IrBuilder::create(); fusion->addInput(attention_scores); fusion->addInput(attention_mask); @@ -49,8 +50,8 @@ static void setupSoftmaxDropout( attention_scores = div(attention_scores, divisor); attention_scores = add(attention_scores, attention_mask); auto attention_probs = softmax(attention_scores, kReductionAxis); - auto prob = new Double(kDropoutProbability); - auto scale = new Double(kScale); + auto prob = IrBuilder::create(kDropoutProbability); + auto scale = IrBuilder::create(kScale); auto dropout_results = dropout(attention_probs, prob, scale); auto output = 
dropout_results.output; diff --git a/benchmarks/cpp/nvfuser/utils.cpp b/benchmarks/cpp/nvfuser/utils.cpp index 053fc693908..daf2b21a053 100644 --- a/benchmarks/cpp/nvfuser/utils.cpp +++ b/benchmarks/cpp/nvfuser/utils.cpp @@ -16,8 +16,8 @@ std::string toString(ReductionParams rparams) { if (rparams.schedule_3D) { ss << "3D Schedule // " << "Outer Reduction: " - << (rparams.cross_block_outer_reduce ? "cross block / " : "") - << (rparams.cross_grid_outer_reduce ? "cross grid / " : "") + << (rparams.cross_block_outer_reduction ? "cross block / " : "") + << (rparams.cross_grid_outer_reduction ? "cross grid / " : "") << (rparams.split_grid_dim_outer_reduction ? "split grid dim / " : ""); if (rparams.batches_per_block_outer_reduction > 1 || rparams.persistent_kernel) { @@ -38,9 +38,9 @@ std::string toString(ReductionParams rparams) { } ss << " // Inner Reduction Domain: " - << (rparams.cross_block_inner_reduce ? "cross block reduction / " : "") + << (rparams.cross_block_inner_reduction ? "cross block reduction / " : "") << (rparams.pad_inner_reduction_to_warp ? "pad to warp / " : "") - << (rparams.cross_grid_inner_reduce ? "cross grid reduction / " : ""); + << (rparams.cross_grid_inner_reduction ? "cross grid reduction / " : ""); if (rparams.batches_per_block_inner_reduction > 1 || rparams.persistent_kernel) { @@ -48,7 +48,7 @@ std::string toString(ReductionParams rparams) { << " / "; } - ss << (rparams.cross_grid_inner_reduce && + ss << (rparams.cross_grid_inner_reduction && rparams.split_grid_dim_inner_reduction ? "split grid dimension / " : "") diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index f229ac2679e..b7a4489abe9 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -18,7 +19,7 @@ #include #include #include -#include +#include #include #include #include @@ -86,19 +87,95 @@ void checkIntValue( void checkIntValue( kir::ExpressionEvaluator& evaluator, - const kir::Val* val, - kir::Int::ScalarType expected_value) { + const Val* val, + Int::ScalarType expected_value) { const auto actual_value = evaluator.evaluate(val); TORCH_CHECK(actual_value.has_value()); TORCH_CHECK(actual_value.value() == expected_value); } -bool isPredicated(TensorView* tv, GpuLower& gpulw) { - auto parent_scope = gpulw.lowerValue(tv)->definition()->parentScope(); - if (parent_scope->isA()) { - return !parent_scope->predicate()->value()->isConst(); +TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) { + auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as()); + TensorView* matching_tv = nullptr; + for (auto lowered_tv : used_tvs) { + if (lowered_tv->name() == tv->name()) { + matching_tv = lowered_tv; + } + } + TORCH_INTERNAL_ASSERT(matching_tv != nullptr); + return matching_tv; +} + +class PredicatedChecker : public kir::IrVisitor { + public: + // Checks if the provided tv is written to within a non-trivial conditional + static bool isPredicated(TensorView* tv, GpuLower& gpulw) { + PredicatedChecker checker( + loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs()); + return checker.is_predicated_; + } + + private: + PredicatedChecker() = delete; + + PredicatedChecker(TensorView* tv, std::vector exprs) : tv_(tv) { + kir::IrVisitor::handle(exprs); + } + + using kir::IrVisitor::handle; + bool is_predicated_ = false; + bool predicated_ite_ = false; + TensorView* tv_ = nullptr; + + void handle(kir::IfThenElse* ite) final { + auto prev_ite = predicated_ite_; + predicated_ite_ = 
!ite->predicate()->value()->isConstScalar(); + kir::IrVisitor::handle(ite); + predicated_ite_ = prev_ite; + } + + void handle(Expr* expr) final { + if (expr->outputs().size() && expr->outputs()[0]->isA()) { + auto ti = expr->outputs()[0]->as(); + if (ti->view() == tv_) { + is_predicated_ = is_predicated_ | predicated_ite_; + } + } + kir::IrVisitor::handle(expr); + } +}; + +class UnswitchInElseChecker : public kir::IrVisitor { + public: + // Checks if there are any unswitched for loops within an else clause + static bool check(GpuLower& gpulw) { + UnswitchInElseChecker checker(gpulw.kernel()->topLevelExprs()); + return checker.found_in_else_; + } + + private: + UnswitchInElseChecker() = delete; + UnswitchInElseChecker(std::vector exprs) { + kir::IrVisitor::handle(exprs); + } + + using kir::IrVisitor::handle; + bool within_else_ = false; + bool found_in_else_ = false; + + void handle(kir::IfThenElse* ite) final { + auto prev_within_else = within_else_; + within_else_ = true; + kir::IrVisitor::handle(ite->elseBody().exprs()); + within_else_ = prev_within_else; + } + + void handle(kir::ForLoop* for_loop) final { + if (for_loop->iter_domain()->getParallelType() == ParallelType::Unswitch) { + found_in_else_ = found_in_else_ || within_else_; + } + kir::IrVisitor::handle(for_loop); } - return true; }; } // namespace @@ -110,7 +187,7 @@ bool isPredicated(TensorView* tv, GpuLower& gpulw) { // (These tests exercise IrGraphGenerator through a non-trivial IR, // to make sure that it runs w/o crashing. The actual output is not // validated) -TEST(NVFuserTest, IrGraphGenerator_CUDA) { +TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -123,10 +200,12 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv2 = add(tv0, new Double(3.141)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.141)); TensorView* tv3 = broadcast(tv0, {false, true, false, true}); - TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv3); - TensorView* tv5 = clamp(tv4, new Double(0.f), new Double(1.f)); + TensorView* tv4 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv3); + TensorView* tv5 = clamp( + tv4, IrBuilder::create(0.f), IrBuilder::create(1.f)); TensorView* tv6 = add(tv2, tv2); // Another checkpoint before adding outputs @@ -149,7 +228,7 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) { .empty()); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(-1)->parallelize(ParallelType::TIDx); @@ -162,11 +241,11 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) { .empty()); } -TEST(NVFuserTest, FusionDispatch_CUDA) { +TEST_F(NVFuserTest, FusionDispatch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* f = new Double{2.f}; + Double* f = IrBuilder::create(2.f); std::stringstream ss1, ss2, ss3; ss1 << f; ss2 << static_cast(f); @@ -177,14 +256,14 @@ TEST(NVFuserTest, FusionDispatch_CUDA) { } // Evaluate basic scalar operations with constant values -TEST(NVFuserTest, FusionExprEvalConstants_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalConstants_CUDA) { Fusion fusion; FusionGuard fg(&fusion); ExpressionEvaluator evaluator(&fusion); - auto* a = new Int(7); - auto* b = new Int(3); + auto* a = IrBuilder::create(7); + auto* b = IrBuilder::create(3); // Avoid div operation because it casts int operands to float checkIntValue(evaluator, neg(a), -7); @@ -195,17 
+274,17 @@ TEST(NVFuserTest, FusionExprEvalConstants_CUDA) { } // Evaluate basic scalar operations with bound values -TEST(NVFuserTest, FusionExprEvalBindings_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalBindings_CUDA) { Fusion fusion; FusionGuard fg(&fusion); ExpressionEvaluator evaluator(&fusion); - auto* a = new Int(); - auto* b = new Int(); + auto* a = IrBuilder::create(); + auto* b = IrBuilder::create(); auto* c = add(a, b); auto* d = neg(ceilDiv(c, b)); - auto* e = new Int(0); + auto* e = IrBuilder::create(0); // trying to evaluate before binding should give empty results TORCH_CHECK(!evaluator.evaluate(a).has_value()); @@ -240,7 +319,7 @@ TEST(NVFuserTest, FusionExprEvalBindings_CUDA) { } // Evaluate expressions in a simple IR -TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -251,7 +330,7 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); @@ -296,16 +375,16 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { } // Evaluate expressions in a more complex IR -TEST(NVFuserTest, FusionExprEvalComplex_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalComplex_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); TensorView* tv6 = add(tv0, tv3); @@ -348,7 +427,7 @@ TEST(NVFuserTest, FusionExprEvalComplex_CUDA) { } // Evaluate expressions post lowering -TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalPostLower_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -359,7 +438,7 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); @@ -375,8 +454,8 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); - auto* bid_x = add(tv3->axis(0)->extent(), new Int(0)); - auto* tid_x = add(tv3->axis(-1)->extent(), new Int(0)); + auto* bid_x = add(tv3->axis(0)->extent(), IrBuilder::create(0)); + auto* tid_x = add(tv3->axis(-1)->extent(), IrBuilder::create(0)); // Lower GpuLower gpulw(&fusion); @@ -406,37 +485,39 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { } // Kernel IR: Evaluate basic scalar operations with constant values -TEST(NVFuserTest, FusionKernelExprEvalConstants_CUDA) { - kir::Kernel kernel; - kir::IrBuilder ir_builder(&kernel); +TEST_F(NVFuserTest, FusionKernelExprEvalConstants_CUDA) { + Fusion fusion; + kir::Kernel kernel(&fusion); + FusionGuard fg((&kernel)->as()); - auto a = ir_builder.create(7); - auto b = ir_builder.create(3); - auto c = ir_builder.subExpr(a, b); - auto d = ir_builder.divExpr(a, b); - auto e = ir_builder.mulExpr(c, d); + auto a = IrBuilder::create(7); + auto b = IrBuilder::create(3); + auto c = IrBuilder::subExpr(a, b); + auto d = IrBuilder::divExpr(a, 
b); + auto e = IrBuilder::mulExpr(c, d); kir::ExpressionEvaluator evaluator; - checkIntValue(evaluator, ir_builder.negExpr(a), -7); - checkIntValue(evaluator, ir_builder.addExpr(a, b), 10); - checkIntValue(evaluator, ir_builder.negExpr(e), -8); - checkIntValue(evaluator, ir_builder.modExpr(a, b), 1); - checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3); + checkIntValue(evaluator, IrBuilder::negExpr(a), -7); + checkIntValue(evaluator, IrBuilder::addExpr(a, b), 10); + checkIntValue(evaluator, IrBuilder::negExpr(e), -8); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); } // Kernel IR: Evaluate basic scalar operations with bound values -TEST(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { - kir::Kernel kernel; - kir::IrBuilder ir_builder(&kernel); +TEST_F(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { + Fusion fusion; + kir::Kernel kernel(&fusion); + FusionGuard fg((&kernel)->as()); kir::ExpressionEvaluator evaluator; - auto a = ir_builder.create(c10::nullopt); - auto b = ir_builder.create(c10::nullopt); - auto c = ir_builder.addExpr(a, b); - auto d = ir_builder.negExpr(ir_builder.ceilDivExpr(c, b)); - auto e = ir_builder.create(0); + auto a = IrBuilder::create(c10::nullopt); + auto b = IrBuilder::create(c10::nullopt); + auto c = IrBuilder::addExpr(a, b); + auto d = IrBuilder::negExpr(IrBuilder::ceilDivExpr(c, b)); + auto e = IrBuilder::create(0); // trying to evaluate before binding should give empty results TORCH_CHECK(!evaluator.evaluate(a).has_value()); @@ -452,9 +533,9 @@ TEST(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { ASSERT_ANY_THROW(evaluator.bind(e, 100)); checkIntValue(evaluator, c, 10); - checkIntValue(evaluator, ir_builder.subExpr(a, b), 4); - checkIntValue(evaluator, ir_builder.modExpr(a, b), 1); - checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3); + checkIntValue(evaluator, IrBuilder::subExpr(a, b), 4); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); checkIntValue(evaluator, d, -4); // Reset the evaluation context @@ -464,13 +545,13 @@ TEST(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { evaluator.bind(b, 5); checkIntValue(evaluator, c, 7); - checkIntValue(evaluator, ir_builder.subExpr(a, b), -3); - checkIntValue(evaluator, ir_builder.modExpr(a, b), 2); - checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::subExpr(a, b), -3); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 2); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 1); checkIntValue(evaluator, d, -2); } -TEST(NVFuserTest, FusionClear_CUDA) { +TEST_F(NVFuserTest, FusionClear_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -483,7 +564,7 @@ TEST(NVFuserTest, FusionClear_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); @@ -507,14 +588,14 @@ TEST(NVFuserTest, FusionClear_CUDA) { TORCH_CHECK(fusion.inputs().empty()); TORCH_CHECK(fusion.outputs().empty()); - TORCH_CHECK(!fusion.hasReduction()); + TORCH_CHECK(ir_utils::getReductionOps(&fusion).empty()); // 3. 
Rebuild the IR { TensorView* tv0 = makeSymbolicTensor(3); TensorView* tv1 = makeSymbolicTensor(3); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addInput(tv0); @@ -539,7 +620,7 @@ TEST(NVFuserTest, FusionClear_CUDA) { at::Tensor input2 = at::randn_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; @@ -548,7 +629,7 @@ TEST(NVFuserTest, FusionClear_CUDA) { TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionCopy_CUDA) { +TEST_F(NVFuserTest, FusionCopy_CUDA) { Fusion original_fusion; // Create the test IR @@ -557,7 +638,7 @@ TEST(NVFuserTest, FusionCopy_CUDA) { auto tv0 = makeSymbolicTensor(3); auto tv1 = makeSymbolicTensor(3); - auto tv2 = add(tv1, new Double(2.0)); + auto tv2 = add(tv1, IrBuilder::create(2.0)); auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); original_fusion.addInput(tv0); @@ -622,7 +703,7 @@ TEST(NVFuserTest, FusionCopy_CUDA) { ASSERT_EQ(original_kernel, clone_kernel); } -TEST(NVFuserTest, FusionMove_CUDA) { +TEST_F(NVFuserTest, FusionMove_CUDA) { Fusion fusion; // Create the test IR @@ -631,7 +712,7 @@ TEST(NVFuserTest, FusionMove_CUDA) { auto tv0 = makeSymbolicTensor(3); auto tv1 = makeSymbolicTensor(3); - auto tv2 = add(tv1, new Double(2.0)); + auto tv2 = add(tv1, IrBuilder::create(2.0)); auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); fusion.addInput(tv0); @@ -692,28 +773,28 @@ TEST(NVFuserTest, FusionMove_CUDA) { ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str()); } -TEST(NVFuserTest, FusionSimpleArith_CUDA) { +TEST_F(NVFuserTest, FusionSimpleArith_CUDA) { std::stringstream ss1, ss2; Fusion fusion; FusionGuard fg(&fusion); - Double* d1 = new Double(1.f); - Double* d2 = new Double{2.f}; - Double* d3 = new Double(); + Double* d1 = IrBuilder::create(1.f); + Double* d2 = IrBuilder::create(2.f); + Double* d3 = IrBuilder::create(); // Disrupt the fusion to make sure guard works well { Fusion fusion2; FusionGuard fg(&fusion2); - Double* d1 = new Double(1.f); - Double* d2 = new Double(2.f); + Double* d1 = IrBuilder::create(1.f); + Double* d2 = IrBuilder::create(2.f); add(d1, d2); ss2 << fusion2; } - new BinaryOp(BinaryOpType::Add, d3, d1, d2); + IrBuilder::create(BinaryOpType::Add, d3, d1, d2); ss1 << fusion; TORCH_CHECK( @@ -721,22 +802,22 @@ TEST(NVFuserTest, FusionSimpleArith_CUDA) { "Error where explicit add nodes don't match implicit add nodes."); } -TEST(NVFuserTest, FusionSimpleTypePromote_CUDA) { +TEST_F(NVFuserTest, FusionSimpleTypePromote_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* d4 = new Double{4.f}; - Int* i1 = new Int{3}; + Double* d4 = IrBuilder::create(4.f); + Int* i1 = IrBuilder::create(3); auto d5 = add(d4, i1); TORCH_CHECK(d5->getDataType() == DataType::Double); } -TEST(NVFuserTest, FusionRegister_CUDA) { +TEST_F(NVFuserTest, FusionRegister_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* v1 = new Double{1.f}; - Double* v2 = new Double{2.f}; + Double* v1 = IrBuilder::create(1.f); + Double* v2 = IrBuilder::create(2.f); Val* v3 = binaryOp(BinaryOpType::Add, v1, v2); Val* v4 = binaryOp(BinaryOpType::Add, v1, v2); TORCH_CHECK(v1->name() + 1 == v2->name()); @@ -748,14 +829,18 @@ TEST(NVFuserTest, FusionRegister_CUDA) { // dummy expr with 2 outputs only for toposort test. 
struct DummyExpr : public Expr { ~DummyExpr() = default; - DummyExpr(Val* _outlhs, Val* _outrhs, Val* _lhs, Val* _rhs) - : Expr(ExprType::UnaryOp) // Not terribly safe... + DummyExpr( + IrBuilderPasskey passkey, + Val* _outlhs, + Val* _outrhs, + Val* _lhs, + Val* _rhs) + : Expr(passkey, ExprType::UnaryOp) // Not terribly safe... { addOutput(_outlhs); addOutput(_outrhs); addInput(_lhs); addInput(_rhs); - this->name_ = FusionGuard::getCurFusion()->registerExpr(this); } DummyExpr(const DummyExpr& other) = delete; DummyExpr& operator=(const DummyExpr& other) = delete; @@ -763,7 +848,7 @@ struct DummyExpr : public Expr { DummyExpr& operator=(DummyExpr&& other) = delete; }; -TEST(NVFuserTest, FusionTopoSort_CUDA) { +TEST_F(NVFuserTest, FusionTopoSort_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -771,23 +856,23 @@ TEST(NVFuserTest, FusionTopoSort_CUDA) { // e1: v4 = add(v3, v2) // e2: v5 = add(v2, v4) // e3: v6 = add(v5, v5) - Double* v0 = new Double{1.f}; - Double* v1 = new Double{2.f}; - Double* v2 = new Double(); - Double* v3 = new Double(); - Double* v4 = new Double(); - Double* v5 = new Double(); - Double* v6 = new Double(); + Double* v0 = IrBuilder::create(1.f); + Double* v1 = IrBuilder::create(2.f); + Double* v2 = IrBuilder::create(); + Double* v3 = IrBuilder::create(); + Double* v4 = IrBuilder::create(); + Double* v5 = IrBuilder::create(); + Double* v6 = IrBuilder::create(); std::vector inputs = {v0, v1}; for (auto val : inputs) { fusion.addInput(val); } - Expr* e0 = new DummyExpr(v3, v2, v1, v0); - Expr* e1 = new BinaryOp(BinaryOpType::Add, v4, v3, v2); - Expr* e2 = new BinaryOp(BinaryOpType::Add, v5, v2, v4); - Expr* e3 = new BinaryOp(BinaryOpType::Add, v6, v5, v5); + Expr* e0 = IrBuilder::create(v3, v2, v1, v0); + Expr* e1 = IrBuilder::create(BinaryOpType::Add, v4, v3, v2); + Expr* e2 = IrBuilder::create(BinaryOpType::Add, v5, v2, v4); + Expr* e3 = IrBuilder::create(BinaryOpType::Add, v6, v5, v5); fusion.addOutput(v2); fusion.addOutput(v3); @@ -824,7 +909,7 @@ TEST(NVFuserTest, FusionTopoSort_CUDA) { TORCH_CHECK(v6->definition()->name() == 3); } -TEST(NVFuserTest, FusionTensor_CUDA) { +TEST_F(NVFuserTest, FusionTensor_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); Fusion fusion; @@ -833,7 +918,7 @@ TEST(NVFuserTest, FusionTensor_CUDA) { { auto tensor = at::randn({2, 3, 4, 5}, options); auto tensor_type = TensorType::create(tensor); - auto fuser_tensor = new TensorView(tensor_type); + auto fuser_tensor = IrBuilder::create(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); @@ -856,7 +941,7 @@ TEST(NVFuserTest, FusionTensor_CUDA) { auto sliced_tensor = tensor.slice(1, 0, -1, 2); auto tensor_type = TensorType::create(sliced_tensor); - auto fuser_tensor = new TensorView(tensor_type); + auto fuser_tensor = IrBuilder::create(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); @@ -873,7 +958,7 @@ TEST(NVFuserTest, FusionTensor_CUDA) { auto tensor = at::randn({2, 3, 4, 5}, options); auto permuted_tensor = tensor.permute({0, 3, 1, 2}); auto tensor_type = TensorType::create(permuted_tensor); - auto fuser_tensor = new TensorView(tensor_type); + auto fuser_tensor = IrBuilder::create(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); 
TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); @@ -888,15 +973,15 @@ TEST(NVFuserTest, FusionTensor_CUDA) { } } -TEST(NVFuserTest, FusionFilterVals_CUDA) { +TEST_F(NVFuserTest, FusionFilterVals_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); auto tv1 = makeSymbolicTensor(1); - auto scalar0 = new Double(0); - auto scalar1 = new Int(0); - auto scalar2 = new Int(1); + auto scalar0 = IrBuilder::create(0); + auto scalar1 = IrBuilder::create(0); + auto scalar2 = IrBuilder::create(1); const std::vector vals = {tv0, scalar0, tv1, scalar1, scalar2}; @@ -926,7 +1011,7 @@ TEST(NVFuserTest, FusionFilterVals_CUDA) { "Not expecting any results"); } -TEST(NVFuserTest, FusionTVSplit_CUDA) { +TEST_F(NVFuserTest, FusionTVSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -943,7 +1028,7 @@ TEST(NVFuserTest, FusionTVSplit_CUDA) { static_cast(outer)->lhs()->sameAs( tv->getRootDomain()[2]->extent()) && static_cast(static_cast(outer)->rhs()) - ->sameAs(new Int(2))); + ->sameAs(IrBuilder::create(2))); IterDomain* inner = static_cast(tv->axis(3)); TORCH_CHECK( @@ -952,7 +1037,7 @@ TEST(NVFuserTest, FusionTVSplit_CUDA) { static_cast(inner->extent())->value().value() == 2); } -TEST(NVFuserTest, FusionTVMerge_CUDA) { +TEST_F(NVFuserTest, FusionTVMerge_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -970,7 +1055,7 @@ TEST(NVFuserTest, FusionTVMerge_CUDA) { tv->getRootDomain()[2]->extent()); } -TEST(NVFuserTest, FusionTVReorder_CUDA) { +TEST_F(NVFuserTest, FusionTVReorder_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1020,38 +1105,43 @@ TEST(NVFuserTest, FusionTVReorder_CUDA) { TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); } -TEST(NVFuserTest, FusionEquality_CUDA) { +TEST_F(NVFuserTest, FusionEquality_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* fval1 = new Double(); + Double* fval1 = IrBuilder::create(); Double* fval1_copy = fval1; - Double* fval2 = new Double(); - Double* fone = new Double(1.0); + Double* fval2 = IrBuilder::create(); + Double* fone = IrBuilder::create(1.0); TORCH_CHECK(fval1->sameAs(fval1_copy)); TORCH_CHECK(!fval1->sameAs(fval2)); TORCH_CHECK(!fone->sameAs(fval1)); - TORCH_CHECK(fone->sameAs(new Double(1.0))); + TORCH_CHECK(fone->sameAs(IrBuilder::create(1.0))); - Int* ival1 = new Int(); + Int* ival1 = IrBuilder::create(); Int* ival1_copy = ival1; - Int* ival2 = new Int(); - Int* ione = new Int(1); + Int* ival2 = IrBuilder::create(); + Int* ione = IrBuilder::create(1); TORCH_CHECK(ival1->sameAs(ival1_copy)); TORCH_CHECK(!ival1->sameAs(ival2)); TORCH_CHECK(!ione->sameAs(ival1)); - TORCH_CHECK(ione->sameAs(new Int(1))); - - BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1); - BinaryOp* add1_copy = - new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1); - BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Double(), fval1, ival1); - - UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Double(), fval1); - UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Double(), fval2); - UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Double(), fval1); + TORCH_CHECK(ione->sameAs(IrBuilder::create(1))); + + BinaryOp* add1 = IrBuilder::create( + BinaryOpType::Add, IrBuilder::create(), fval1, ival1); + BinaryOp* add1_copy = IrBuilder::create( + BinaryOpType::Add, IrBuilder::create(), fval1, ival1); + BinaryOp* sub1 = IrBuilder::create( + BinaryOpType::Sub, IrBuilder::create(), fval1, ival1); + + UnaryOp* neg1 = IrBuilder::create( + UnaryOpType::Neg, 
IrBuilder::create(), fval1); + UnaryOp* neg2 = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval2); + UnaryOp* neg1_copy = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval1); TORCH_CHECK(add1->sameAs(add1_copy)); TORCH_CHECK(!add1->sameAs(sub1)); @@ -1061,22 +1151,22 @@ TEST(NVFuserTest, FusionEquality_CUDA) { TORCH_CHECK(!neg1->sameAs(neg2)); } -TEST(NVFuserTest, FusionDependency_CUDA) { +TEST_F(NVFuserTest, FusionDependency_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* d0 = new Double(0.f); - Double* d1 = new Double(1.f); + Double* d0 = IrBuilder::create(0.f); + Double* d1 = IrBuilder::create(1.f); auto d2 = add(d0, d1); auto d3 = add(d2, d2); - Double* d4 = new Double(4.f); - Double* d5 = new Double(5.f); + Double* d4 = IrBuilder::create(4.f); + Double* d5 = IrBuilder::create(5.f); auto d6 = add(d4, d5); - Double* d7 = new Double(7.f); - Double* d8 = new Double(8.f); + Double* d7 = IrBuilder::create(7.f); + Double* d8 = IrBuilder::create(8.f); auto d9 = add(d7, d8); auto d10 = add(d6, d9); @@ -1131,7 +1221,7 @@ TEST(NVFuserTest, FusionDependency_CUDA) { TORCH_CHECK(dep_chain.empty()); } -TEST(NVFuserTest, FusionParser_CUDA) { +TEST_F(NVFuserTest, FusionParser_CUDA) { // This test may not pass if using a custom block sync as there may // be additional calls. Skip the test as it's not specifically // relevant with block synchronizatin. @@ -1174,31 +1264,31 @@ TEST(NVFuserTest, FusionParser_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { if ((((((((((nvfuser_index_t)blockIdx.x) * 1) + 0) * 1) + 0) * 128) + ((nvfuser_index_t)threadIdx.x)) < T0.size[0])) { - constexpr nvfuser_index_t ki183 = 0; + constexpr nvfuser_index_t i33 = 0; float T5[1]; - constexpr nvfuser_index_t ki217 = 0; - T5[ki217] = 0; - constexpr nvfuser_index_t ki208 = 0; - T5[ki208] - = T1[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki183) * 1) + ki208) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; + constexpr nvfuser_index_t i45 = 0; + T5[i45] = 0; + constexpr nvfuser_index_t i41 = 0; + T5[i41] + = T1[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i33) * 1) + i41) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; float T4[1]; - constexpr nvfuser_index_t ki223 = 0; - T4[ki223] = 0; - constexpr nvfuser_index_t ki203 = 0; - T4[ki203] - = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki183) * 1) + ki203) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; + constexpr nvfuser_index_t i47 = 0; + T4[i47] = 0; + constexpr nvfuser_index_t i39 = 0; + T4[i39] + = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i33) * 1) + i39) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; float T6[1]; - constexpr nvfuser_index_t ki192 = 0; + constexpr nvfuser_index_t i37 = 0; float T2[1]; T2[0] - = T4[ki192] - * T5[ki192]; - T6[ki192] + = T4[i37] + * T5[i37]; + T6[i37] = T2[0] - * T4[ki192]; - constexpr nvfuser_index_t ki185 = 0; - T3[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki183) * 1) + ki185) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] - = T6[ki185]; + * T4[i37]; + constexpr nvfuser_index_t i35 = 0; + T3[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i33) * 1) + i35) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] + = T6[i35]; } } )"; @@ -1227,62 +1317,25 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te } FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1, input2}, lparams); auto outputs = fe.runFusion({input1, input2}, lparams); at::Tensor output_ref = input1 * input2 * input1; 
TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionForLoop_CUDA) { -// TODO(kir): re-enable this test -// due to the current "GpuLower guard" approach, we can only create -// kernel IR during GpuLower::lower() -#if 0 - Fusion fusion; - FusionGuard fg(&fusion); - - const auto TV0 = new TensorView( - new TensorDomain({new IterDomain(new Int(0), new Int(16))}), - DataType::Float); - const auto TV1 = new TensorView( - new TensorDomain({new IterDomain(new Int(0), new Int(16))}), - DataType::Float); - - fusion.addInput(TV0); - fusion.addInput(TV1); - - auto ID0 = new kir::IterDomain(new IterDomain(new Int(0), new Int(8))); - - TensorView* TV2 = add(TV0, TV1); - BinaryOp* op = static_cast(TV2->definition(); - fusion.addOutput(TV2); - - auto fl = new kir::ForLoop(new kir::Int(c10::nullopt), ID0, {op}); - - std::stringstream result; - std::stringstream ref; - result << fl; - ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}"; - - if (result.str().compare(ref.str()) == 0) { - std::stringstream err_msg; - err_msg << "ForLoop printing has changed or something has gone wrong. " - << result.str() << "\n does not match reference: " << ref.str() - << std::endl; - TORCH_CHECK(false, err_msg.str()); - } -#endif -} - -TEST(NVFuserTest, FusionOuterSplit_CUDA) { +TEST_F(NVFuserTest, FusionOuterSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(3); - new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0)); - TensorView* tv1 = add(tv0, new Double(2.0)); - TensorView* tv2 = add(tv1, new Double(3.0)); + IrBuilder::create( + BinaryOpType::Add, + tv0, + IrBuilder::create(0.0), + IrBuilder::create(1.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); fusion.addOutput(tv2); //[I0, I1, I2] @@ -1312,15 +1365,19 @@ TEST(NVFuserTest, FusionOuterSplit_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionCodeGen_CUDA) { +TEST_F(NVFuserTest, FusionCodeGen_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(3); - new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0)); - TensorView* tv1 = add(tv0, new Double(2.0)); - TensorView* tv2 = add(tv1, new Double(3.0)); + IrBuilder::create( + BinaryOpType::Add, + tv0, + IrBuilder::create(0.0), + IrBuilder::create(1.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); fusion.addOutput(tv2); //[I0, I1, I2] @@ -1349,13 +1406,13 @@ TEST(NVFuserTest, FusionCodeGen_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionCodeGen2_CUDA) { +TEST_F(NVFuserTest, FusionCodeGen2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(3); TensorView* tv1 = makeSymbolicTensor(3); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addInput(tv0); @@ -1382,7 +1439,7 @@ TEST(NVFuserTest, FusionCodeGen2_CUDA) { at::Tensor input2 = at::randn_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; @@ -1391,7 +1448,7 @@ TEST(NVFuserTest, FusionCodeGen2_CUDA) { TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionSimplePWise_CUDA) { +TEST_F(NVFuserTest, FusionSimplePWise_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // 
dimensionality of the problem @@ -1407,7 +1464,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -1439,7 +1496,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) { at::Tensor output = at::empty_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; @@ -1448,7 +1505,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionExecKernel_CUDA) { +TEST_F(NVFuserTest, FusionExecKernel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1462,7 +1519,7 @@ TEST(NVFuserTest, FusionExecKernel_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -1490,7 +1547,7 @@ TEST(NVFuserTest, FusionExecKernel_CUDA) { at::Tensor input2 = at::ones_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); at::Tensor check = at::full({1, 128}, 4, options); @@ -1502,7 +1559,7 @@ int ceilDiv_(int a, int b) { return (a + b - 1) / b; } -TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -1517,10 +1574,10 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = add(tv1, new Double(3.0)); - TensorView* tv4 = mul(tv1, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); TensorView* tv5 = add(tv3, tv2); TensorView* tv6 = add(tv5, tv4); @@ -1538,7 +1595,8 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { tv0->computeAt(tv7, 1); - GpuLower gpulw(&fusion); + ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP); + loop_map.build(&fusion); // The this-position of the last tensor should be zero. TORCH_CHECK( @@ -1550,11 +1608,12 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { // The position of every other tensor should be 1. 
for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0))); + + TORCH_CHECK(loop_map.areMapped(tv7->axis(0), tv->axis(0))); } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -1579,14 +1638,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -1600,9 +1659,9 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); @@ -1621,7 +1680,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { tv0->computeAt(tv6, 1); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1643,13 +1702,13 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { std::vector aten_outputs = {t5, t6}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 @@ -1662,7 +1721,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { TensorView* tv1 = makeSymbolicTensor(4); fusion.addInput(tv1); - TensorView* tv2 = mul(tv1, new Double(.979361)); + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); fusion.addOutput(tv3); @@ -1679,7 +1738,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { tv3->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1700,14 +1759,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { at::Tensor cg_output = at::empty_like(t0, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 @@ -1747,7 +1806,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { tv6->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if 
(!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1769,14 +1828,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 @@ -1788,7 +1847,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -1809,14 +1868,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1824,7 +1883,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -1848,26 +1907,26 @@ TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1.0)); + auto tv1 = add(tv0, IrBuilder::create(1.0)); auto tv2 = makeSymbolicTensor(1); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(3.0)); + auto tv3 = add(tv2, IrBuilder::create(3.0)); auto tv4 = add(tv1, tv3); fusion.addOutput(tv4); @@ -1899,9 +1958,6 @@ TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { auto tv5_domain_current = tv5->domain()->domain(); TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain"); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; @@ -1919,25 +1975,27 @@ TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { std::vector aten_inputs = {t0, t2, t6}; std::vector aten_outputs = {t4, t7}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1.0)); + auto tv1 = add(tv0, IrBuilder::create(1.0)); auto tv2 = makeSymbolicTensor(1); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(3.0)); + 
auto tv3 = add(tv2, IrBuilder::create(3.0)); auto tv4 = add(tv1, tv3); fusion.addOutput(tv4); @@ -1964,9 +2022,6 @@ TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { tv2->computeAt(tv4, -1); tv0->computeAt(tv7, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; @@ -1984,13 +2039,15 @@ TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { std::vector aten_inputs = {t0, t2, t6}; std::vector aten_outputs = {t4, t7}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -2005,10 +2062,10 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = add(tv1, new Double(3.0)); - TensorView* tv4 = mul(tv1, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); TensorView* tv5 = add(tv3, tv2); TensorView* tv6 = add(tv5, tv4); @@ -2036,14 +2093,17 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 && tv6->getMaxProducerPosition() == 1); + ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP); + loop_map.build(&fusion); + // The position of every other tensor should be 1. for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0))); + TORCH_CHECK(loop_map.areMapped(tv7->axis(0), tv->axis(0))); } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -2068,14 +2128,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -2089,9 +2149,9 @@ TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); @@ -2110,7 +2170,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { tv0->computeWith(tv6, 1); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -2132,13 +2192,13 
@@ TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { std::vector aten_outputs = {t5, t6}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 @@ -2151,7 +2211,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { TensorView* tv1 = makeSymbolicTensor(4); fusion.addInput(tv1); - TensorView* tv2 = mul(tv1, new Double(.979361)); + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); fusion.addOutput(tv3); @@ -2173,7 +2233,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { tv3->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -2194,14 +2254,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { at::Tensor cg_output = at::empty_like(t0, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 @@ -2240,7 +2300,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { tv6->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -2262,14 +2322,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 @@ -2281,7 +2341,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -2302,14 +2362,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2317,7 +2377,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -2341,14 +2401,14 @@ TEST(NVFuserTest, 
FusionAdvancedComputeWith6_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -2 @@ -2358,9 +2418,9 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv1, new Double(-2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); fusion.addOutput(tv2); fusion.addOutput(tv3); @@ -2387,10 +2447,12 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { TORCH_CHECK( tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1); + ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP); + loop_map.build(&fusion); + // Note that tv2 is also computed at tv3. for (auto tv : {tv1, tv2}) { - TORCH_CHECK( - gpulw.caLoopMap().areMapped(tv->axis(0), computeAtTarget->axis(0))); + TORCH_CHECK(loop_map.areMapped(tv->axis(0), computeAtTarget->axis(0))); } TORCH_CHECK(tv3->getComputeAtPosition() == 0); @@ -2414,7 +2476,7 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -2422,7 +2484,7 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { } // Similar to ComputeAtMultiConsumers, but with a common consumer. 
-TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -2 @@ -2434,11 +2496,11 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv1, new Double(-2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); TensorView* tv4 = add(tv2, tv3); - TensorView* tv5 = mul(tv4, new Double(5.0)); + TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); fusion.addOutput(tv3); fusion.addOutput(tv4); fusion.addOutput(tv5); @@ -2492,14 +2554,14 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -1 @@ -2511,10 +2573,10 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv2, new Double(-1.0)); - TensorView* tv4 = add(tv1, new Double(4.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); + TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); TensorView* tv5 = add(tv3, tv4); fusion.addOutput(tv5); @@ -2541,7 +2603,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // All tensors should have the same dimenionality as the target for (Val* val : fusion.vals()) { - if (fusion.hasInput(val) || + if (val->isFusionInput() || val->getValType().value() != ValType::TensorView) { continue; } @@ -2555,7 +2617,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { } for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (!fusion.hasInput(tv)) { + if (!tv->isFusionInput()) { tv->axis(1)->parallelize(ParallelType::Unroll); tv->axis(-1)->parallelize(ParallelType::TIDx); } @@ -2574,7 +2636,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { at::Tensor cg_output = at::empty_like(aten_input, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( @@ -2583,7 +2645,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // Similar to the above common consumer test but adds an additional // tensor that has no common consumer with the other tensors. 
-TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -1 @@ -2596,12 +2658,12 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv2, new Double(-1.0)); - TensorView* tv4 = add(tv1, new Double(4.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); + TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); TensorView* tv5 = add(tv3, tv4); - TensorView* tv6 = add(tv1, new Double(6.0)); + TensorView* tv6 = add(tv1, IrBuilder::create(6.0)); fusion.addOutput(tv5); fusion.addOutput(tv6); @@ -2627,7 +2689,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // All tensors should have the same dimenionality as the target for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (fusion.hasInput(tv)) { + if (tv->isFusionInput()) { continue; } TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); @@ -2640,7 +2702,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = val->as(); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -2664,7 +2726,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -2673,7 +2735,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor // that does not have data dependency with the consumer. -TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv1 * -2 @@ -2686,13 +2748,13 @@ TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv1, new Double(-2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); TensorView* tv4 = add(tv2, tv3); - TensorView* tv5 = mul(tv4, new Double(5.0)); + TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); // Notice that tv6 is not a consumer of tv4. 
- TensorView* tv6 = mul(tv1, new Double(6.0)); + TensorView* tv6 = mul(tv1, IrBuilder::create(6.0)); fusion.addOutput(tv3); fusion.addOutput(tv4); fusion.addOutput(tv5); @@ -2737,7 +2799,7 @@ TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -2822,7 +2884,7 @@ void checkIdMapped( } // namespace -TEST(NVFuserTest, FusionRootMappingBasic_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2876,7 +2938,7 @@ TEST(NVFuserTest, FusionRootMappingBasic_CUDA) { checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain()); } -TEST(NVFuserTest, FusionRootMappingRfactor_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2960,7 +3022,7 @@ TEST(NVFuserTest, FusionRootMappingRfactor_CUDA) { {true, true, false}); } -TEST(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2987,7 +3049,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { {true, false}); } -TEST(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3021,7 +3083,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); } -TEST(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3050,7 +3112,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { {true, false}); } -TEST(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3095,13 +3157,13 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { } // Reproducer of issue #749 -TEST(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {false, true}); auto tv4 = add(tv0, tv3); @@ -3153,13 +3215,13 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { } // Similar to RootMappingReductionDependency5 but with rFactor -TEST(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {false, true}); auto tv4 = add(tv0, tv3); @@ -3227,7 +3289,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { {true, true}); } -TEST(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3265,7 +3327,9 @@ TEST(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { {false, false}); } 
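The hunks above and below repeat a handful of mechanical changes: test cases move from TEST to the TEST_F(NVFuserTest, ...) fixture, scalar IR nodes are built through IrBuilder::create instead of raw new, fusion.hasInput(val) checks become val->isFusionInput(), reduction detection goes through ir_utils::getReductionOps(&fusion), loop mappings come from an explicitly constructed ComputeAtMap rather than gpulw.caLoopMap(), and FusionExecutor::compileFusion is now handed the inputs up front. The sketch below is only an illustration of those patterns and is not part of the patch; the test name and tensor sizes are made up, and it assumes the headers and using-directives already present in this test file as well as IrBuilder::create<Double>(...) as the intended scalar factory instantiation.

// Hypothetical example, not in the diff; mirrors the post-patch API usage.
TEST_F(NVFuserTest, FusionIrBuilderMigrationSketch_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  // Scalars are created through IrBuilder rather than `new Double(...)`.
  TensorView* tv1 = mul(tv0, IrBuilder::create<Double>(0.5));
  TensorView* tv2 = add(tv1, IrBuilder::create<Double>(1.0));
  fusion.addOutput(tv2);

  // Inline the intermediate into the output's loop nest.
  tv1->computeAt(tv2, -1);

  // Input checks query the Val directly instead of asking the Fusion.
  for (Val* val : fusion.vals()) {
    if (!val->isFusionInput() &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = val->as<TensorView>();
      tv->axis(0)->parallelize(ParallelType::BIDx);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({16, 32}, options);
  std::vector<IValue> aten_inputs = {t0};

  // compileFusion now also receives the inputs, matching the hunks above.
  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  auto aten_output = t0.mul(0.5).add(1.0);
  testValidate(
      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}

The same shape applies wherever a test previously compiled without inputs: build and schedule the fusion first, then pass the concrete inputs (and, where the test uses them, LaunchParams) to compileFusion before calling runFusion.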
-TEST(NVFuserTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { +TEST_F( + NVFuserTest, + FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3299,7 +3363,7 @@ TEST(NVFuserTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { {false, true}); } -TEST(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3386,7 +3450,7 @@ TEST(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { {true, false}); } -TEST(NVFuserTest, FusionRootMappingBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3426,7 +3490,7 @@ TEST(NVFuserTest, FusionRootMappingBroadcast_CUDA) { } // Reproducer of issue #723 -TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3461,7 +3525,7 @@ TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto t3 = t0; @@ -3470,13 +3534,13 @@ TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = broadcast(tv1, {true, false}); auto tv3 = broadcast(tv1, {false, true}); auto tv4 = add(tv2, tv3); @@ -3486,7 +3550,7 @@ TEST(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) { ASSERT_ANY_THROW(tv1->computeAt(tv4, 1)); } -TEST(NVFuserTest, FusionScalarInputs_CUDA) { +TEST_F(NVFuserTest, FusionScalarInputs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3495,13 +3559,13 @@ TEST(NVFuserTest, FusionScalarInputs_CUDA) { TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - Double* d0 = new Double(); + Double* d0 = IrBuilder::create(); fusion.addInput(d0); - Double* d1 = new Double(); + Double* d1 = IrBuilder::create(); fusion.addInput(d1); - Double* d2 = new Double(); + Double* d2 = IrBuilder::create(); fusion.addInput(d2); - Double* d3 = new Double(); + Double* d3 = IrBuilder::create(); fusion.addInput(d3); Val* d4 = mul(d0, d1); Val* d5 = sub(d2, d3); @@ -3524,7 +3588,7 @@ TEST(NVFuserTest, FusionScalarInputs_CUDA) { tv4->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -3568,14 +3632,14 @@ TEST(NVFuserTest, FusionScalarInputs_CUDA) { at::Scalar(fl3)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionLoopUnroll_CUDA) { +TEST_F(NVFuserTest, FusionLoopUnroll_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3589,7 +3653,7 @@ TEST(NVFuserTest, FusionLoopUnroll_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* 
tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -3621,7 +3685,7 @@ TEST(NVFuserTest, FusionLoopUnroll_CUDA) { at::Tensor input1 = at::randn({129, 13, 3}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}); auto outputs = fe.runFusion({input0, input1}); TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0)))); @@ -3636,11 +3700,11 @@ Val* gen_jit_operand(std::pair desc) { return makeSymbolicTensor(2, desc.second); } else if (desc.first == ValType::Scalar) { if (desc.second == DataType::Float) { - return new Double(); + return IrBuilder::create(); } else if (desc.second == DataType::Double) { - return new Double(); + return IrBuilder::create(); } else if (desc.second == DataType::Int) { - return new Int(); + return IrBuilder::create(); } else { TORCH_CHECK(false, "Not currently supported type: ", desc.first); } @@ -3763,7 +3827,7 @@ void test_op( at::manual_seed(0); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs_ivalues); fe.runFusion(aten_inputs_ivalues, output_vect); cudaDeviceSynchronize(); @@ -3809,7 +3873,7 @@ void test_op( std::make_index_sequence{}); } -TEST(NVFuserTest, FusionUnaryOps_CUDA) { +TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { using OpTuple = std::tuple; @@ -3833,7 +3897,6 @@ TEST(NVFuserTest, FusionUnaryOps_CUDA) { OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"}, OpTuple{at::floor, UnaryOpType::Floor, "floor"}, OpTuple{at::frac, UnaryOpType::Frac, "frac"}, - // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"}, OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"}, OpTuple{at::log, UnaryOpType::Log, "log"}, OpTuple{at::log10, UnaryOpType::Log10, "log10"}, @@ -3904,7 +3967,7 @@ TEST(NVFuserTest, FusionUnaryOps_CUDA) { } } -TEST(NVFuserTest, FusionBinaryOps_CUDA) { +TEST_F(NVFuserTest, FusionBinaryOps_CUDA) { using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&); using OpTuple = std::tuple; @@ -4009,7 +4072,7 @@ TEST(NVFuserTest, FusionBinaryOps_CUDA) { } } -TEST(NVFuserTest, FusionTernaryOps_CUDA) { +TEST_F(NVFuserTest, FusionTernaryOps_CUDA) { std::vector dtypes = {DataType::Double, DataType::Float}; for (auto dtype : dtypes) { @@ -4024,9 +4087,15 @@ TEST(NVFuserTest, FusionTernaryOps_CUDA) { /*JIT Func */ [&](Val* in1) -> Val* { if (dtype == DataType::Float) { - return clamp(in1, new Double(0.f), new Double(1.f)); + return clamp( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } else { - return clamp(in1, new Double(0.f), new Double(1.f)); + return clamp( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } }, /*Output */ std::make_pair(ValType::TensorView, dtype), @@ -4043,9 +4112,15 @@ TEST(NVFuserTest, FusionTernaryOps_CUDA) { /*JIT Func */ [&](Val* in1) -> Val* { if (dtype == DataType::Float) { - return threshold(in1, new Double(0.f), new Double(1.f)); + return threshold( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } else { - return threshold(in1, new Double(0.f), new Double(1.f)); + return threshold( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } }, /*Output */ std::make_pair(ValType::TensorView, dtype), @@ -4070,7 +4145,7 @@ TEST(NVFuserTest, FusionTernaryOps_CUDA) { } } -TEST(NVFuserTest, FusionCompoundOps_CUDA) { +TEST_F(NVFuserTest, FusionCompoundOps_CUDA) { std::vector dtypes = {DataType::Double, DataType::Float}; for (auto dtype : dtypes) { @@ -4114,7 +4189,7 @@ TEST(NVFuserTest, FusionCompoundOps_CUDA) { } 
} -TEST(NVFuserTest, FusionCastOps_CUDA) { +TEST_F(NVFuserTest, FusionCastOps_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4139,7 +4214,7 @@ TEST(NVFuserTest, FusionCastOps_CUDA) { const at::ArrayRef input_ivalues(inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, input_ivalues); auto outputs = fe.runFusion(input_ivalues); ref_output = at::_cast_Half(at::_cast_Double(input1)); @@ -4156,7 +4231,7 @@ TEST(NVFuserTest, FusionCastOps_CUDA) { // Start off simple, block on the outer dim // block stride + thread all reduce + unrolling on inner dim -TEST(NVFuserTest, FusionReduction1_CUDA) { +TEST_F(NVFuserTest, FusionReduction1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4165,10 +4240,13 @@ TEST(NVFuserTest, FusionReduction1_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, 128); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -4207,7 +4285,7 @@ TEST(NVFuserTest, FusionReduction1_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -4216,7 +4294,7 @@ TEST(NVFuserTest, FusionReduction1_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction2_CUDA) { +TEST_F(NVFuserTest, FusionReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4225,7 +4303,8 @@ TEST(NVFuserTest, FusionReduction2_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4278,14 +4357,14 @@ TEST(NVFuserTest, FusionReduction2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1}); testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction3_CUDA) { +TEST_F(NVFuserTest, FusionReduction3_CUDA) { // What if Z participates in the reduction with X? 
Fusion fusion; FusionGuard fg(&fusion); @@ -4295,7 +4374,8 @@ TEST(NVFuserTest, FusionReduction3_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4328,7 +4408,7 @@ TEST(NVFuserTest, FusionReduction3_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); auto aten_output = aten_input.to(at::kDouble).sum({1}); @@ -4337,7 +4417,7 @@ TEST(NVFuserTest, FusionReduction3_CUDA) { &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction4_CUDA) { +TEST_F(NVFuserTest, FusionReduction4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4351,7 +4431,8 @@ TEST(NVFuserTest, FusionReduction4_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2); + TensorView* tv3 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); // tv3[I0, R1] = tv2[I0, I1] TensorView* tv4 = makeSymbolicTensor(1); @@ -4393,7 +4474,7 @@ TEST(NVFuserTest, FusionReduction4_CUDA) { at::Tensor t4 = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1, t4}); auto cg_outputs = fe.runFusion({t0, t1, t4}); auto t2 = t0.add(t1); @@ -4404,7 +4485,7 @@ TEST(NVFuserTest, FusionReduction4_CUDA) { &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction5_CUDA) { +TEST_F(NVFuserTest, FusionReduction5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4413,7 +4494,8 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4431,7 +4513,7 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { tv1->axis(0)->parallelize(ParallelType::BIDy); for (auto* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { val->as()->axis(-1)->parallelize(ParallelType::TIDx); } @@ -4446,7 +4528,7 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { at::Tensor cg_output = at::empty({bidy, tidx}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -4454,7 +4536,7 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction6_CUDA) { +TEST_F(NVFuserTest, FusionReduction6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4466,10 +4548,13 @@ TEST(NVFuserTest, FusionReduction6_CUDA) { fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(2, bdimx); // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] @@ -4508,14 +4593,14 @@ 
TEST(NVFuserTest, FusionReduction6_CUDA) { at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1, 2}); testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultiGridReduction_CUDA) { +TEST_F(NVFuserTest, FusionMultiGridReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4540,7 +4625,7 @@ TEST(NVFuserTest, FusionMultiGridReduction_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); std::vector aten_outputs = { @@ -4548,7 +4633,7 @@ TEST(NVFuserTest, FusionMultiGridReduction_CUDA) { testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultiGridReduction2_CUDA) { +TEST_F(NVFuserTest, FusionMultiGridReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4566,7 +4651,7 @@ TEST(NVFuserTest, FusionMultiGridReduction2_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionReductionTFT_CUDA) { +TEST_F(NVFuserTest, FusionReductionTFT_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4575,7 +4660,8 @@ TEST(NVFuserTest, FusionReductionTFT_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4613,7 +4699,7 @@ TEST(NVFuserTest, FusionReductionTFT_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -4621,7 +4707,7 @@ TEST(NVFuserTest, FusionReductionTFT_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { +TEST_F(NVFuserTest, FusionReductionOuterSplit_CUDA) { // based off FusionReduction4 Fusion fusion; FusionGuard fg(&fusion); @@ -4636,7 +4722,8 @@ TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2); + TensorView* tv3 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); // tv3[I0, R1] = tv2[I0, I1] TensorView* tv4 = makeSymbolicTensor(1); @@ -4676,7 +4763,7 @@ TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { at::Tensor t4 = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1, t4}); auto cg_outputs = fe.runFusion({t0, t1, t4}); auto t2 = t0.add(t1); @@ -4687,7 +4774,7 @@ TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBranches_CUDA) { +TEST_F(NVFuserTest, FusionBranches_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4699,7 +4786,7 @@ TEST(NVFuserTest, FusionBranches_CUDA) { fusion.addInput(tv1); fusion.addInput(tv2); - auto tv3 = add(tv0, new Double(1.0)); + auto tv3 = add(tv0, IrBuilder::create(1.0)); auto tv4 = add(tv3, tv1); auto tv5 = add(tv3, tv2); auto tv6 = add(tv4, tv5); @@ -4735,7 +4822,7 @@ TEST(NVFuserTest, FusionBranches_CUDA) { std::vector aten_inputs = {t0, t1, t2}; 
- fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t3 = t0.add(1.0); @@ -4747,14 +4834,14 @@ TEST(NVFuserTest, FusionBranches_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast1_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1.5)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.5)); TensorView* tv2 = makeSymbolicTensor(2); fusion.addInput(tv2); @@ -4797,14 +4884,14 @@ TEST(NVFuserTest, FusionSimpleBCast1_CUDA) { std::vector aten_inputs = {t0, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast2_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4821,7 +4908,7 @@ TEST(NVFuserTest, FusionSimpleBCast2_CUDA) { TensorView* tv4 = makeSymbolicTensor(2); fusion.addInput(tv4); - TensorView* tv5 = sub(tv4, new Double(0.1)); + TensorView* tv5 = sub(tv4, IrBuilder::create(0.1)); TensorView* tv6 = broadcast(tv5, {true, false, false}); @@ -4856,28 +4943,30 @@ TEST(NVFuserTest, FusionSimpleBCast2_CUDA) { std::vector aten_inputs = {t0, t1, t4}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast3_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views std::vector dom; - dom.push_back(new IterDomain(new Int(0), new Int())); - dom.push_back(new IterDomain( - new Int(0), - new Int(1), + dom.push_back(IrBuilder::create( + IrBuilder::create(0), IrBuilder::create())); + dom.push_back(IrBuilder::create( + IrBuilder::create(0), + IrBuilder::create(1), ParallelType::Serial, IterType::BroadcastWithStride)); // tv0[I1, B{1}] - TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create(dom), DataType::Float); fusion.addInput(tv0); // tv1[I0, I1, I2] @@ -4908,26 +4997,28 @@ TEST(NVFuserTest, FusionSimpleBCast3_CUDA) { at::Tensor cg_output = at::empty({x, y, z}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast4_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views std::vector dom; - dom.push_back(new IterDomain( - new Int(0), - new Int(1), + dom.push_back(IrBuilder::create( + IrBuilder::create(0), + IrBuilder::create(1), ParallelType::Serial, IterType::BroadcastWithStride)); - dom.push_back(new IterDomain(new Int(0), new Int())); - TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); + dom.push_back(IrBuilder::create( + IrBuilder::create(0), IrBuilder::create())); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create(dom), DataType::Float); TensorView* tv1 = makeSymbolicTensor(3); 
fusion.addInput(tv0); @@ -4963,30 +5054,35 @@ TEST(NVFuserTest, FusionSimpleBCast4_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast5_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); constexpr int m = 2, k = 3, n = 4; - auto zero = new Int(0); - auto M = new IterDomain(zero, new Int(m)); - auto K = new IterDomain(zero, new Int(k)); - auto N = new IterDomain(zero, new Int(n)); + auto zero = IrBuilder::create(0); + auto M = IrBuilder::create(zero, IrBuilder::create(m)); + auto K = IrBuilder::create(zero, IrBuilder::create(k)); + auto N = IrBuilder::create(zero, IrBuilder::create(n)); // Set up your input tensor views - TensorView* tv0 = - new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create( + std::vector({M, K}), std::vector({true, true})), + DataType::Float); // Note: IterDomain must not be reused, so K needs to be cloned. - TensorView* tv1 = new TensorView( - new TensorDomain({K->clone(), N}, {true, true}), DataType::Float); + TensorView* tv1 = IrBuilder::create( + IrBuilder::create( + std::vector({K->clone(), N}), + std::vector({true, true})), + DataType::Float); fusion.addInput(tv0); fusion.addInput(tv1); @@ -5018,21 +5114,21 @@ TEST(NVFuserTest, FusionSimpleBCast5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComplexBCast1_CUDA) { +TEST_F(NVFuserTest, FusionComplexBCast1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); int x = 2, y = 3, z = 4; auto tv0 = makeConcreteTensor({y}); - auto tv1 = div(tv0, new Double(2.0)); + auto tv1 = div(tv0, IrBuilder::create(2.0)); auto tv2 = broadcast(tv1, {false, true}); auto tv3 = makeConcreteTensor({y, z}); auto tv4 = mul(tv2, tv3); @@ -5074,21 +5170,21 @@ TEST(NVFuserTest, FusionComplexBCast1_CUDA) { std::vector aten_inputs = {t0, t3, t6}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComplexBCast2_CUDA) { +TEST_F(NVFuserTest, FusionComplexBCast2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); int x = 2, y = 3, z = 4; auto tv0 = makeConcreteTensor({y, z}); - auto tv1 = div(tv0, new Double(2.0)); + auto tv1 = div(tv0, IrBuilder::create(2.0)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = makeConcreteTensor({x, y}); @@ -5119,7 +5215,7 @@ TEST(NVFuserTest, FusionComplexBCast2_CUDA) { at::Tensor t4 = at::randn({x, y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t4}); auto cg_outputs = fe.runFusion({t0, t4}); auto t1 = t0.div(2.0); @@ -5131,7 +5227,7 @@ TEST(NVFuserTest, FusionComplexBCast2_CUDA) { &fusion, {cg_outputs}, {t0, t4}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5143,7 +5239,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) 
{ fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1.0)); + auto tv2 = add(tv0, IrBuilder::create(1.0)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); @@ -5178,14 +5274,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) { std::vector aten_inputs = {t0, t1}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5197,7 +5293,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1.0)); + auto tv2 = add(tv0, IrBuilder::create(1.0)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); @@ -5232,14 +5328,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { std::vector aten_inputs = {t0, t1}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5250,7 +5346,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1.0)); + auto tv2 = add(tv0, IrBuilder::create(1.0)); auto tv3 = add(tv2, tv1); fusion.addOutput(tv3); @@ -5266,14 +5362,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5283,7 +5379,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { TensorView* tv1 = makeConcreteTensor({4, 4, 8}); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = broadcast(tv2, {true, false, false}); TensorView* tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -5298,14 +5394,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5315,7 +5411,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) { TensorView* tv1 = makeSymbolicTensor(3); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = broadcast(tv2, {true, false, true}); TensorView* tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -5336,14 +5432,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto 
cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5371,7 +5467,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) { scheduleReduction(&fusion, reduction_params.value()); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}, reduction_params.value().lparams); auto cg_outputs = fe.runFusion({input0, input1}, reduction_params.value().lparams); @@ -5388,7 +5484,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) { reduction_params.value().lparams); } -TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing7_CUDA) { // Might be able to use this one without 6 as the heuristics in 6 may change // and this test is to cover the same issue. Fusion fusion; @@ -5417,15 +5513,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) { tv4->axis(0)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto at_t0 = at::randn({numel_x}, options); auto at_t1 = at::randn({numel_x, numel_y}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t1}); auto cg_outputs = fe.runFusion({at_t0, at_t1}); auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) @@ -5436,7 +5531,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing8_CUDA) { // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -5464,15 +5559,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) { tv4->axis(0)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto at_t0 = at::randn({numel_x}, options); auto at_t1 = at::randn({numel_x, numel_y}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t1}); auto cg_outputs = fe.runFusion({at_t0, at_t1}); auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) @@ -5483,7 +5577,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing9_CUDA) { // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -5493,7 +5587,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = mul(tv1, new Double(2)); + auto tv2 = mul(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); auto tv3 = makeSymbolicTensor(3); @@ -5513,7 +5607,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); auto at_t1 = at_t0.unsqueeze(-1); @@ -5525,7 +5619,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, 
FusionAdvancedIndexing10_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing10_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5539,7 +5633,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -5575,7 +5669,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { at::Tensor output = at::empty_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; @@ -5584,7 +5678,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing11_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5596,7 +5690,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv1, new Double(1.0)); + auto tv2 = add(tv1, IrBuilder::create(1.0)); auto tv3 = broadcast(tv2, {true, false, true, true}); auto tv4 = add(tv3, tv0); @@ -5631,7 +5725,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { std::vector aten_inputs = {t0, t1}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5639,16 +5733,16 @@ TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { } // Intended to stress the lowering of our code generator -TEST(NVFuserTest, FusionAdvancedLowering1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeConcreteTensor({9, 5}); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv1, new Double(3)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); TensorView* tv4 = sum(tv3, {1}); fusion.addOutput(tv2); @@ -5671,15 +5765,14 @@ TEST(NVFuserTest, FusionAdvancedLowering1_CUDA) { std::vector aten_outputs = {t2, t4}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5691,7 +5784,7 @@ TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { TensorView* tv2 = makeSymbolicTensor(3); fusion.addInput(tv2); - TensorView* tv3 = add(tv0, new Double(1)); + TensorView* tv3 = add(tv0, IrBuilder::create(1)); TensorView* tv4 = broadcast(tv3, {false, true}); TensorView* tv5 = add(tv4, tv1); TensorView* tv6 = add(tv5, tv2); @@ -5727,8 +5820,7 @@ TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { std::vector aten_outputs = {t6}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5736,7 +5828,7 @@ TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { } // TODO: Complete test -TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5746,13 
+5838,13 @@ TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { fusion.addInput(tv1); // [b0, i1] - auto tv2 = add(tv0, new Double(2.0)); + auto tv2 = add(tv0, IrBuilder::create(2.0)); // [i0, i1] - auto tv3 = add(tv1, new Double(3.0)); + auto tv3 = add(tv1, IrBuilder::create(3.0)); // [b0, i1] - auto tv4 = add(tv2, new Double(4.0)); + auto tv4 = add(tv2, IrBuilder::create(4.0)); // [io, i1] auto tv5 = add(tv2, tv3); @@ -5776,8 +5868,7 @@ TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { std::vector aten_outputs = {t4, t5}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5787,7 +5878,7 @@ TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { // This excercises indexing with broadcast root axes. Non-broadcast // axes need to be preferred when propagating index exprs to root // axes. See, e.g., Index::getConsumerIndex_impl. -TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5804,9 +5895,6 @@ TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { tv4->split(0, 8); tv0->computeAt(tv4, 1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); const int bx = 10; const int by = 20; @@ -5815,6 +5903,8 @@ TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { at::Tensor t3 = at::randn({bx, by, bz}, options); std::vector aten_inputs = {t0, t3}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = @@ -5824,7 +5914,7 @@ TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedLowering5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5854,15 +5944,14 @@ TEST(NVFuserTest, FusionAdvancedLowering5_CUDA) { std::vector aten_outputs = {t3}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedLowering6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5902,8 +5991,7 @@ TEST(NVFuserTest, FusionAdvancedLowering6_CUDA) { std::vector aten_outputs = {t5, t7}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5911,7 +5999,7 @@ TEST(NVFuserTest, FusionAdvancedLowering6_CUDA) { } // Test a simple Gemm but also play around with fusion executor features -TEST(NVFuserTest, FusionSimpleGemm_CUDA) { +TEST_F(NVFuserTest, FusionSimpleGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5978,7 +6066,7 @@ TEST(NVFuserTest, FusionSimpleGemm_CUDA) { at::Tensor t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); // Lets specify a few bounds in launch params to make sure it works fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); @@ -5996,7 +6084,7 @@ TEST(NVFuserTest, FusionSimpleGemm_CUDA) { } // Softmax with a 1D tensor. Parallelized only with a single thread block. 
-TEST(NVFuserTest, FusionSoftmax1D_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax1D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6042,7 +6130,7 @@ TEST(NVFuserTest, FusionSoftmax1D_CUDA) { at::Tensor t3_output = at::empty_like(cg_output, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); fe.runFusion({t0}, {cg_output}); auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false); @@ -6051,7 +6139,7 @@ TEST(NVFuserTest, FusionSoftmax1D_CUDA) { } // Softmax with a 1D tensor with input normalization. -TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6063,8 +6151,8 @@ TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { fusion.addInput(input_tv0); // Normalize with the max value before computing exp. - TensorView* max_val_tv1 = - reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0); + TensorView* max_val_tv1 = reductionOp( + BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true}); TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); @@ -6111,7 +6199,7 @@ TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { at::Tensor t3_output = at::empty({dimx}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); @@ -6121,7 +6209,7 @@ TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { // Softmax with a 3D tensor, where the inner-most 3rd dimension is // normalized. Pallelized with multiple thread blocks. -TEST(NVFuserTest, FusionSoftmax3D_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax3D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6171,7 +6259,7 @@ TEST(NVFuserTest, FusionSoftmax3D_CUDA) { at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); @@ -6181,7 +6269,7 @@ TEST(NVFuserTest, FusionSoftmax3D_CUDA) { } // Softmax with a 3D tensor with input normalization. -TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6195,8 +6283,8 @@ TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { fusion.addInput(input_tv0); // Normalize with the max value before computing exp. 
- TensorView* max_val_tv1 = - reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0); + TensorView* max_val_tv1 = reductionOp( + BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true}); TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); @@ -6246,7 +6334,7 @@ TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); @@ -6254,7 +6342,7 @@ TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { +TEST_F(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6265,7 +6353,7 @@ TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { auto tv1 = sum(tv0, {1}); auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv0, new Double(1.0)); + auto tv3 = add(tv0, IrBuilder::create(1.0)); auto tv4 = mul(tv2, tv3); @@ -6280,10 +6368,7 @@ TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { } // Similar to FusionReduction but uses grid reduction -TEST(NVFuserTest, FusionGridReduction1_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridReduction1_CUDA) { const int gdimx = 32; const int bdimx = 128; @@ -6295,10 +6380,13 @@ TEST(NVFuserTest, FusionGridReduction1_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -6331,7 +6419,7 @@ TEST(NVFuserTest, FusionGridReduction1_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6341,10 +6429,7 @@ TEST(NVFuserTest, FusionGridReduction1_CUDA) { } // Same test as the above but uses BIDy and TIDx for reduction -TEST(NVFuserTest, FusionGridReduction2_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridReduction2_CUDA) { const int gdimy = 32; const int bdimx = 128; @@ -6356,10 +6441,13 @@ TEST(NVFuserTest, FusionGridReduction2_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -6391,7 +6479,7 @@ TEST(NVFuserTest, FusionGridReduction2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, 
{input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6400,7 +6488,7 @@ TEST(NVFuserTest, FusionGridReduction2_CUDA) { } // Same test but uses BIDy and BIDz for reduction. No TID used. -TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction3dim1_CUDA) { // Grid reductions when there aren't any threads are serial reductions // keep these numbers low so our error isn't too high compared to normal cuda // reductions @@ -6415,10 +6503,13 @@ TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, gdimy); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -6450,7 +6541,7 @@ TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6459,7 +6550,7 @@ TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { } // Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0 -TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction3dim0_CUDA) { // Grid reductions when there aren't any threads are serial reductions // keep these numbers low so our error isn't too high compared to normal cuda // reductions @@ -6474,10 +6565,13 @@ TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { fusion.addInput(tv0); // tv1[R0, I1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {0}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {0}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(0, gdimy); // tv1[R0o, R0i{128}, I1] = tv0[I0, I1] @@ -6507,7 +6601,7 @@ TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({0}); @@ -6516,7 +6610,7 @@ TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { } // This is similar to the FusionReduction, but swaps BIDx and TIDx -TEST(NVFuserTest, FusionGridReduction4_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6528,10 +6622,13 @@ TEST(NVFuserTest, FusionGridReduction4_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, gdimx); // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1] @@ -6570,7 +6667,7 @@ TEST(NVFuserTest, FusionGridReduction4_CUDA) { 
   at::Tensor cg_output = at::empty({numel_x}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {input});
   fe.runFusion({input}, {cg_output});
   auto aten_output = input.to(at::kDouble).sum({1});
@@ -6580,7 +6677,7 @@ TEST(NVFuserTest, FusionGridReduction4_CUDA) {
 // Grid reduction with 2D thread blocks but only TIDx and BIDx are
 // mapped to a reduction dim
-TEST(NVFuserTest, FusionGridReduction5_CUDA) {
+TEST_F(NVFuserTest, FusionGridReduction5_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6593,10 +6690,13 @@ TEST(NVFuserTest, FusionGridReduction5_CUDA) {
   fusion.addInput(tv0);
   // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
+  TensorView* tv1 =
+      reductionOp(BinaryOpType::Add, {1}, IrBuilder::create<Double>(0), tv0);
   fusion.addOutput(tv1);
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+  TORCH_CHECK(
+      ir_utils::getReductionOps(&fusion).size(),
+      "Could not detect reduction in fusion.");
   tv1->split(1, bdimx);
   // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
@@ -6624,7 +6724,7 @@ TEST(NVFuserTest, FusionGridReduction5_CUDA) {
   at::Tensor input = at::randn({numel_x, numel_y}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {input});
   auto cg_outputs = fe.runFusion({input});
   auto aten_output = input.to(at::kDouble).sum({1});
@@ -6632,7 +6732,7 @@ TEST(NVFuserTest, FusionGridReduction5_CUDA) {
 }
 // Similar to FusionGridReduction1 but with 3D tensors
-TEST(NVFuserTest, FusionGridReduction6_CUDA) {
+TEST_F(NVFuserTest, FusionGridReduction6_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6641,10 +6741,13 @@ TEST(NVFuserTest, FusionGridReduction6_CUDA) {
   fusion.addInput(tv0);
   // tv1[I0, R1, R2] = tv0[I0, I1, I2]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0);
+  TensorView* tv1 =
+      reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create<Double>(0), tv0);
   fusion.addOutput(tv1);
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+  TORCH_CHECK(
+      ir_utils::getReductionOps(&fusion).size(),
+      "Could not detect reduction in fusion.");
   // Splitting for TID
   tv1->split(2, 128);
@@ -6686,7 +6789,7 @@ TEST(NVFuserTest, FusionGridReduction6_CUDA) {
   at::Tensor cg_output = at::empty({numel_x}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {input});
   fe.runFusion({input}, {cg_output});
   auto aten_output = input.to(at::kDouble).sum({1, 2});
@@ -6696,7 +6799,7 @@ TEST(NVFuserTest, FusionGridReduction6_CUDA) {
 }
 // See issue #1049
-TEST(NVFuserTest, FusionGridReduction7_CUDA) {
+TEST_F(NVFuserTest, FusionGridReduction7_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6718,7 +6821,7 @@ TEST(NVFuserTest, FusionGridReduction7_CUDA) {
   at::Tensor cg_output = at::empty({numel_x}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {input});
   auto out = fe.runFusion({input});
   auto aten_output = input.sum({0});
@@ -6726,7 +6829,7 @@ TEST(NVFuserTest, FusionGridReduction7_CUDA) {
   testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__);
 }
-TEST(NVFuserTest, FusionGridReduction8_CUDA) {
+TEST_F(NVFuserTest, FusionGridReduction8_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6746,7 +6849,7 @@ TEST(NVFuserTest, FusionGridReduction8_CUDA) {
   at::Tensor input = at::randn({numel_x, numel_y}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {input});
   auto out = fe.runFusion({input});
   auto aten_output =
       input.sum({0});
@@ -6754,10 +6857,7 @@ TEST(NVFuserTest, FusionGridReduction8_CUDA) {
   testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__);
 }
-TEST(NVFuserTest, FusionGridReduction9_CUDA) {
-  if (at::cuda::getDeviceProperties(0)->major < 6) {
-    return;
-  }
+TEST_F(NVFuserTest, FusionGridReduction9_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6788,7 +6888,7 @@ TEST(NVFuserTest, FusionGridReduction9_CUDA) {
   at::ArrayRef<IValue> aten_inputs = {t0, t2};
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, aten_inputs);
   auto cg_output = fe.runFusion(aten_inputs);
   auto aten_output = t0.sum({1}).add(t2);
@@ -6796,7 +6896,7 @@ TEST(NVFuserTest, FusionGridReduction9_CUDA) {
   testValidate(&fusion, cg_output, {t0, t2}, {aten_output}, __LINE__, __FILE__);
 }
-TEST(NVFuserTest, FusionGridReduction10_CUDA) {
+TEST_F(NVFuserTest, FusionGridReduction10_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6831,7 +6931,7 @@ TEST(NVFuserTest, FusionGridReduction10_CUDA) {
   at::Tensor t0 = at::randn({numel_w, numel_x, numel_y, numel_z}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {t0});
   auto cg_output = fe.runFusion({t0});
   auto aten_output = t0.sum({1, 2, 3});
@@ -6839,7 +6939,7 @@ TEST(NVFuserTest, FusionGridReduction10_CUDA) {
   testValidate(&fusion, cg_output, {t0}, {aten_output}, __LINE__, __FILE__);
 }
-TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
+TEST_F(NVFuserTest, FusionNonRedAxisBind_CUDA) {
   int bid_x = 3;
   int tid_x = 2;
   int red_dim = 0;
@@ -6851,8 +6951,8 @@ TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
   TensorView* tv0 = makeSymbolicTensor(2);
   fusion.addInput(tv0);
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
+  TensorView* tv1 = reductionOp(
+      BinaryOpType::Add, {red_dim}, IrBuilder::create<Double>(0), tv0);
   fusion.addOutput(tv1);
   tv1->split(-1, tid_x);
@@ -6863,7 +6963,7 @@ TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
   at::Tensor input = at::randn({16, bid_x * tid_x}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {input});
   auto cg_outputs = fe.runFusion({input});
   auto aten_output = input.to(at::kDouble).sum({red_dim});
@@ -6871,7 +6971,7 @@ TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
   testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
 }
-TEST(NVFuserTest, FusionSplitBCast_CUDA) {
+TEST_F(NVFuserTest, FusionSplitBCast_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6881,8 +6981,8 @@ TEST(NVFuserTest, FusionSplitBCast_CUDA) {
   fusion.addInput(input_tv0);
   fusion.addInput(input_tv1);
-  TensorView* sum_tv2 =
-      reductionOp(BinaryOpType::Add, {2}, new Double(0), input_tv0);
+  TensorView* sum_tv2 = reductionOp(
+      BinaryOpType::Add, {2}, IrBuilder::create<Double>(0), input_tv0);
   TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
   TensorView* output_tv4 = div(input_tv1, bcast_tv3);
@@ -6915,11 +7015,11 @@ TEST(NVFuserTest, FusionSplitBCast_CUDA) {
   at::Tensor cg_output = at::empty({32, 32, 128}, options);
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {t0, t1});
   fe.runFusion({t0, t1}, {cg_output});
 }
-TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
+TEST_F(NVFuserTest, FusionBCastInnerDim_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -6933,7 +7033,7 @@ TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
   TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
 }
-TEST(NVFuserTest, FusionBCastReduce_CUDA) {
+TEST_F(NVFuserTest, FusionBCastReduce_CUDA) {
   Fusion fusion;
   FusionGuard
fg(&fusion); @@ -6949,14 +7049,16 @@ TEST(NVFuserTest, FusionBCastReduce_CUDA) { // Multiple consumer reduction with computeAt // https://github.com/csarofeen/pytorch/issues/110 -TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) { +TEST_F(NVFuserTest, FusionReductionMultiConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); auto tv1 = unaryOp(UnaryOpType::Exp, tv0); - auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Double(0), tv1); - auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Double(0), tv1); + auto tv2 = + reductionOp(BinaryOpType::Max, {-1}, IrBuilder::create(0), tv1); + auto tv3 = + reductionOp(BinaryOpType::Min, {-1}, IrBuilder::create(0), tv1); auto tv4 = add(tv2, tv3); fusion.addOutput(tv4); tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort); @@ -6964,7 +7066,7 @@ TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) { TORCH_CHECK(tv1->getComputeAtPosition() == 2); } -TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { for (const auto i : c10::irange(2)) { Fusion fusion; FusionGuard fg(&fusion); @@ -6973,8 +7075,8 @@ TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = add(tv1, tv2); // Set outputs tv2 or tv1 and then tv3 if (i == 0) { @@ -6996,7 +7098,7 @@ TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { aten_input + 1, (aten_input + 1) * 2}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -7004,7 +7106,7 @@ TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { } } -TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7012,8 +7114,8 @@ TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = add(tv1, tv2); fusion.addOutput(tv3); @@ -7029,14 +7131,14 @@ TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { at::Tensor cg_output = at::empty_like(aten_input, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7045,10 +7147,10 @@ TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { TensorView* tv0 = makeConcreteTensor({dimx, dimy}); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv2, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv2, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = mul(tv2, tv4); fusion.addOutput(tv5); @@ -7065,14 +7167,14 @@ TEST(NVFuserTest, 
FusionComputeAtExprOrder3_CUDA) { auto aten_output = t2.mul(t4); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { +TEST_F(NVFuserTest, FusionZeroDimComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7080,7 +7182,7 @@ TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { fusion.addInput(tv0); auto tv1 = sum(tv0, {0}); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); TORCH_CHECK(tv2->nDims() == 0); tv1->computeAt(tv2, 0); @@ -7090,14 +7192,14 @@ TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { auto aten_output = aten_input.to(at::kDouble).sum() + 1; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionZeroDimBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7130,14 +7232,14 @@ TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) { at::Tensor cg_output = at::empty({}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroDimReduction_CUDA) { +TEST_F(NVFuserTest, FusionZeroDimReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7166,14 +7268,14 @@ TEST(NVFuserTest, FusionZeroDimReduction_CUDA) { at::Tensor cg_output = at::empty({}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) { +TEST_F(NVFuserTest, FusionBCastAfterReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 128; @@ -7218,14 +7320,14 @@ TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) { std::vector aten_inputs = {t0, t4}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t4}); auto cg_outputs = fe.runFusion({t0, t4}); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionOutputBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionOutputBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7243,15 +7345,14 @@ TEST(NVFuserTest, FusionOutputBroadcast_CUDA) { auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { +TEST_F(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7270,15 +7371,14 @@ TEST(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { aten_input.to(at::kDouble).sum({0, 2, -1}, /*keepdim=*/true); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, 
__FILE__); } -TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { +TEST_F(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -7291,7 +7391,11 @@ TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { fusion.addInput(tv0); TensorView* tv1 = reductionOp( - BinaryOpType::Add, {red_dim}, new Double(0), tv0, /*keep_dim=*/true); + BinaryOpType::Add, + {red_dim}, + IrBuilder::create(0), + tv0, + /*keep_dim=*/true); fusion.addOutput(tv1); @@ -7307,11 +7411,10 @@ TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); scheduleReduction(&fusion, reduction_params.value()); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto lparams = reduction_params.value().lparams; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7325,7 +7428,7 @@ TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { lparams); } -TEST(NVFuserTest, FusionSumTo_CUDA) { +TEST_F(NVFuserTest, FusionSumTo_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7340,7 +7443,7 @@ TEST(NVFuserTest, FusionSumTo_CUDA) { sum_to_shape.begin(), sum_to_shape.end(), std::back_inserter(sum_to_symb), - [](int s) -> Int* { return new Int(s); }); + [](int s) -> Int* { return IrBuilder::create(s); }); TensorView* tv0 = makeConcreteTensor(tensor_shape); fusion.addInput(tv0); @@ -7355,8 +7458,7 @@ TEST(NVFuserTest, FusionSumTo_CUDA) { auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); TORCH_CHECK( @@ -7367,7 +7469,7 @@ TEST(NVFuserTest, FusionSumTo_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSumToNoop_CUDA) { +TEST_F(NVFuserTest, FusionSumToNoop_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7382,7 +7484,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { sum_to_shape.begin(), sum_to_shape.end(), std::back_inserter(sum_to_symb), - [](int s) -> Int* { return new Int(s); }); + [](int s) -> Int* { return IrBuilder::create(s); }); TensorView* tv0 = makeConcreteTensor(tensor_shape); fusion.addInput(tv0); @@ -7390,7 +7492,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { TensorView* tv1 = sum_to(tv0, sum_to_symb); // Dummy operator to avoid tv0 both input and output - TensorView* tv2 = add(tv1, new Double(0)); + TensorView* tv2 = add(tv1, IrBuilder::create(0)); fusion.addOutput(tv2); const auto options = @@ -7399,8 +7501,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { at::Tensor aten_input = at::randn(tensor_shape_ref, options); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); @@ -7412,7 +7513,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionScheduler_CUDA) { +TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -7424,8 +7525,8 @@ TEST(NVFuserTest, FusionReductionScheduler_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0); 
+ TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -7442,7 +7543,7 @@ TEST(NVFuserTest, FusionReductionScheduler_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); @@ -7458,7 +7559,7 @@ TEST(NVFuserTest, FusionReductionScheduler_CUDA) { } // Simple reduction parallelized on a symbolic size. -TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { +TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7467,7 +7568,8 @@ TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); // Interface should just be a direct split with a Parallel type. We can @@ -7501,7 +7603,7 @@ TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7515,7 +7617,7 @@ TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { lparams); } -TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { const std::vector red_dims = {0, 2}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions @@ -7530,8 +7632,8 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -7547,7 +7649,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); fe.runFusion({aten_input}, {cg_output}, lparams); testValidate( @@ -7561,7 +7663,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { lparams); } -TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { const std::vector red_dims = {1, 3}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions @@ -7575,8 +7677,8 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -7590,7 +7692,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = 
fe.runFusion({aten_input}, lparams); testValidate( @@ -7604,7 +7706,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { lparams); } -TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { std::vector dtypes = { DataType::Double, DataType::Float, DataType::Half}; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 @@ -7661,8 +7763,7 @@ TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7678,7 +7779,7 @@ TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { } } -TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { std::vector dtypes = { DataType::Double, DataType::Float, DataType::Half}; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 @@ -7740,8 +7841,7 @@ TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({axis}); testValidate( @@ -7759,14 +7859,14 @@ TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { } } -TEST(NVFuserTest, FusionCacheBefore_CUDA) { +TEST_F(NVFuserTest, FusionCacheBefore_CUDA) { // TVM Cache Write Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, new Double(1.0)); - TensorView* tv2 = mul(tv1, new Double(3.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -7790,21 +7890,21 @@ TEST(NVFuserTest, FusionCacheBefore_CUDA) { at::Tensor aten_output = (aten_input + 1.0) * 3.0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheAfter_CUDA) { +TEST_F(NVFuserTest, FusionCacheAfter_CUDA) { // TVM Cache Read Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, new Double(1.0)); - TensorView* tv2 = mul(tv1, new Double(3.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -7828,20 +7928,20 @@ TEST(NVFuserTest, FusionCacheAfter_CUDA) { at::Tensor aten_output = (aten_input + 1.0) * 3.0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheFork_CUDA) { +TEST_F(NVFuserTest, FusionCacheFork_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, new Double(1.0)); - TensorView* tv2 = mul(tv1, new Double(3.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); fusion.addInput(tv0); fusion.addOutput(tv1); fusion.addOutput(tv2); @@ -7873,7 +7973,7 @@ TEST(NVFuserTest, FusionCacheFork_CUDA) { 
at::Tensor aten_output2 = aten_output1 * 3.0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -7885,7 +7985,7 @@ TEST(NVFuserTest, FusionCacheFork_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionCacheIndirect_CUDA) { +TEST_F(NVFuserTest, FusionCacheIndirect_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7927,14 +8027,14 @@ TEST(NVFuserTest, FusionCacheIndirect_CUDA) { at::Tensor aten_output = (t1 + (t2 - t3)) - t0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheBcast_CUDA) { +TEST_F(NVFuserTest, FusionCacheBcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7986,22 +8086,22 @@ TEST(NVFuserTest, FusionCacheBcast_CUDA) { t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0)); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { +TEST_F(NVFuserTest, FusionCacheMultiConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv0, new Double(1)); - TensorView* tv4 = add(tv3, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(1)); + TensorView* tv4 = add(tv3, IrBuilder::create(2)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -8025,7 +8125,7 @@ TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { auto aten_output = (aten_input + 1) + 2; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -8037,7 +8137,7 @@ TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionSmem_CUDA) { +TEST_F(NVFuserTest, FusionSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8084,7 +8184,7 @@ TEST(NVFuserTest, FusionSmem_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); testValidate( @@ -8093,7 +8193,7 @@ TEST(NVFuserTest, FusionSmem_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemReduce_CUDA) { +TEST_F(NVFuserTest, FusionSmemReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8133,7 +8233,7 @@ TEST(NVFuserTest, FusionSmemReduce_CUDA) { at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1}); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -8141,7 +8241,7 @@ TEST(NVFuserTest, FusionSmemReduce_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { +TEST_F(NVFuserTest, FusionSmemBlockGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8202,7 +8302,7 @@ TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); FusionExecutor fe; - 
fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); testValidate( @@ -8211,7 +8311,7 @@ TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { +TEST_F(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8291,7 +8391,7 @@ TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -8300,7 +8400,7 @@ TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8309,7 +8409,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { TensorView* max_val = reductionOp( BinaryOpType::Max, {-1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), x); // (M) TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) TensorView* x_max_sub = sub(x, bcast_max); // (M, N) @@ -8336,7 +8436,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { bcast_sum, softmax}); - auto tidx = new Int(); + auto tidx = IrBuilder::create(); fusion.addInput(tidx); for (auto tensor : all_tensors) { @@ -8363,7 +8463,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input, 128}); auto cg_outputs = fe.runFusion({aten_input, 128}); testValidate( @@ -8375,7 +8475,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { +TEST_F(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8401,7 +8501,7 @@ TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { auto lparams = reduction_params.value().lparams; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -8415,7 +8515,7 @@ TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { lparams); } -TEST(NVFuserTest, TestMaskSoftmax_CUDA) { +TEST_F(NVFuserTest, FusionTestMaskSoftmax_CUDA) { // This test is testing the usage of all padding tokens // with softmax like Bert might might use in a full padding // sequence. 
@@ -8456,7 +8556,7 @@ TEST(NVFuserTest, TestMaskSoftmax_CUDA) { auto lparams = reduction_params.value().lparams; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input, aten_mask}, lparams); auto cg_outputs = fe.runFusion({aten_input, aten_mask}, lparams); testValidate( @@ -8470,7 +8570,7 @@ TEST(NVFuserTest, TestMaskSoftmax_CUDA) { lparams); } -TEST(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { +TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -8558,13 +8658,13 @@ TEST(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { +TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); std::vector input_shape{20, 100, 35, 67}; std::vector norm_shape{67}; @@ -8594,7 +8694,7 @@ TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { auto lparams = reduction_params.value().lparams; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -8610,8 +8710,9 @@ TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { lparams); } -TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } auto fusion = std::make_unique(); @@ -8633,8 +8734,8 @@ TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { fusion->addInput(running_mean); fusion->addInput(running_var); - Double* momentum = new Double(kMomentum); - Double* eps = new Double(kEps); + Double* momentum = IrBuilder::create(kMomentum); + Double* eps = IrBuilder::create(kEps); auto result = batch_norm( input, weight, bias, running_mean, running_var, kTraining, momentum, eps); @@ -8681,7 +8782,7 @@ TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { ""); } -TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { +TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8697,12 +8798,12 @@ TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { TensorView* max_sx = reductionOp( BinaryOpType::Max, {-1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), sx); // (M) TensorView* max_dx = reductionOp( BinaryOpType::Max, {-1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), dx); // (M) // Reduction => merge local and shared memory TensorViews @@ -8804,7 +8905,7 @@ TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { aten_output.narrow(1, static_size, dimy - static_size); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_static_in, aten_dynamic_in}); fe.runFusion( {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out}); @@ -8817,7 +8918,7 @@ TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { __FILE__); } -TEST(NVFuserTest, 
FusionPersistentNormLocalShared_CUDA) { +TEST_F(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8830,10 +8931,10 @@ TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { fusion.addInput(sx); fusion.addInput(dx); - Double* gamma = new Double(); - Double* beta = new Double(); - Double* eps = new Double(); - Int* N = new Int(); + Double* gamma = IrBuilder::create(); + Double* beta = IrBuilder::create(); + Double* eps = IrBuilder::create(); + Int* N = IrBuilder::create(); fusion.addInput(gamma); fusion.addInput(beta); fusion.addInput(eps); @@ -8982,7 +9083,7 @@ TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out}); auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); @@ -9003,16 +9104,16 @@ TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views auto x = makeSymbolicTensor(2); - Double* gamma = new Double(); - Double* beta = new Double(); - Double* eps = new Double(); - Int* N = new Int(); + Double* gamma = IrBuilder::create(); + Double* beta = IrBuilder::create(); + Double* eps = IrBuilder::create(); + Int* N = IrBuilder::create(); fusion.addInput(x); fusion.addInput(gamma); fusion.addInput(beta); @@ -9062,7 +9163,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { norm_gamma, norm_gamma_beta}); - auto tidx = new Int(); + auto tidx = IrBuilder::create(); fusion.addInput(tidx); for (auto tensor : all_tensors) { @@ -9105,20 +9206,21 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { aten_input, kGamma, kBeta, kEps, dimy, TIDX}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addInput(tv0); fusion.addOutput(tv1); // tv1[I0, R1] = tv0[I0, I1] @@ -9150,7 +9252,7 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -9165,12 +9267,12 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Algorithm - Int* sym_bsx = new Int(); + Int* sym_bsx = IrBuilder::create(); TensorView* tv0 = makeSymbolicTensor(3); // M, K, N fusion.addInput(tv0); fusion.addInput(sym_bsx); @@ -9213,7 +9315,7 @@ 
TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input, runtime_threadIdx_dim}, lparams); auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams); testValidate( @@ -9229,11 +9331,11 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Int* sym_bsx = new Int(); + Int* sym_bsx = IrBuilder::create(); TensorView* tv0 = makeSymbolicTensor(2); // (M, K) TensorView* tv1 = makeSymbolicTensor(2); // (K, N) TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) @@ -9278,7 +9380,7 @@ TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { LaunchParams lparams(-1, -1, -1, BSX, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( @@ -9294,14 +9396,16 @@ TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = new Int(); // bound to threadIdx.z - Int* symbolic_split_k_tile_dim = new Int(); // bound to blockIdx.x - Int* symbolic_block_k_tile_dim = new Int(); // bound to threadIdx.x + Int* symbolic_m_tile_dim = IrBuilder::create(); // bound to threadIdx.z + Int* symbolic_split_k_tile_dim = + IrBuilder::create(); // bound to blockIdx.x + Int* symbolic_block_k_tile_dim = + IrBuilder::create(); // bound to threadIdx.x // Compile-time integer for tiling int n_smem_tile = 8; // bound to threadIdx.y @@ -9397,10 +9501,6 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - FusionExecutor fe; - // Generate CUDA and compile with nvRTC - fe.compileFusion(&fusion); - // Runtime tiling int m_tile = 4; // bound to threadIdx.z int split_k = 7; // bound to blockIdx.x @@ -9410,6 +9510,9 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); + FusionExecutor fe; + // Generate CUDA and compile with nvRTC + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -9418,13 +9521,14 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { +TEST_F(NVFuserTest, FusionGlobalIntermediate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addInput(tv0); fusion.addOutput(tv1); // tv1[I0, R1] = tv0[I0, I1] @@ -9455,7 +9559,7 @@ TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - 
fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}, lparams); auto cg_outputs = fe.runFusion({input}, lparams); auto aten_output = input.to(at::kDouble).sum({1}); @@ -9470,7 +9574,7 @@ TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { lparams); } -TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { +TEST_F(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9504,18 +9608,18 @@ TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1, t2, t3}); auto cg_outputs = fe.runFusion({t0, t1, t2, t3}); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConstCheck_CUDA) { +TEST_F(NVFuserTest, FusionConstCheck_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - auto one = new Int(1); + auto one = IrBuilder::create(1); TORCH_CHECK(one->isConstScalar()); auto one_x2 = mul(one, one); @@ -9528,7 +9632,7 @@ TEST(NVFuserTest, FusionConstCheck_CUDA) { TORCH_CHECK(one_x4->isConstScalar()); } -TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { +TEST_F(NVFuserTest, FusionUnrollWithAlloc_CUDA) { const std::vector tensor_dims_in = {128, 128}; Fusion fusion; FusionGuard fg(&fusion); @@ -9537,8 +9641,9 @@ TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(0)); - TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1); + TensorView* tv1 = add(tv0, IrBuilder::create(0)); + TensorView* tv2 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); fusion.addOutput(tv2); const auto options = @@ -9562,7 +9667,7 @@ TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { tv1->computeAt(tv2_rf, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = (input + 0).to(at::kDouble).sum(1); @@ -9571,12 +9676,12 @@ TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { } // Test isZeroInt -TEST(NVFuserTest, FusionIsZeroInt_CUDA) { +TEST_F(NVFuserTest, FusionIsZeroInt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Int* x = new Int(0); - Int* y = new Int(1); + Int* x = IrBuilder::create(0); + Int* y = IrBuilder::create(1); Val* z = mul(x, y); TORCH_CHECK(x->isZeroInt()); TORCH_CHECK(!y->isZeroInt()); @@ -9584,12 +9689,12 @@ TEST(NVFuserTest, FusionIsZeroInt_CUDA) { } // Test isOneInt -TEST(NVFuserTest, FusionIsOneInt_CUDA) { +TEST_F(NVFuserTest, FusionIsOneInt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Int* x = new Int(1); - Int* y = new Int(1); + Int* x = IrBuilder::create(1); + Int* y = IrBuilder::create(1); Val* z = mul(x, y); TORCH_CHECK(x->isOneInt()); TORCH_CHECK(y->isOneInt()); @@ -9599,7 +9704,7 @@ TEST(NVFuserTest, FusionIsOneInt_CUDA) { // This is to verify no cycle of computeAt is created. A more complex // variation of this pattern appears in one of the Python tests // (test_random_topo). 
-TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9607,12 +9712,12 @@ TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { fusion.addInput(tv0); // Common intermediate tensor - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); // tv1 -> tv2 - auto tv2 = add(tv1, new Double(2)); + auto tv2 = add(tv1, IrBuilder::create(2)); // tv1 -> tv3 -> tv4 - auto tv3 = add(tv1, new Double(3)); - auto tv4 = add(tv3, new Double(4)); + auto tv3 = add(tv1, IrBuilder::create(3)); + auto tv4 = add(tv3, IrBuilder::create(4)); // NOTE: This should no longer occur as of PR #201. // The order of adding outputs matters. If tv3 is added before tv4, @@ -9639,7 +9744,7 @@ TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { auto t4 = t3 + 4; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); std::vector aten_outputs = {t2, t4, t3}; @@ -9647,7 +9752,7 @@ TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9655,10 +9760,10 @@ TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv0, new Double(2)); - TensorView* tv3 = add(tv1, new Double(3)); - TensorView* tv4 = add(tv1, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); + TensorView* tv4 = add(tv1, IrBuilder::create(4)); fusion.addOutput(tv2); fusion.addOutput(tv3); @@ -9666,9 +9771,6 @@ TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { tv1->computeAt(tv3, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({10, 10}, options); @@ -9684,12 +9786,14 @@ TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9697,11 +9801,11 @@ TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv1, tv3); @@ -9712,9 +9816,6 @@ TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { tv1->computeAt(tv5, -1); tv3->computeAt(tv5, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({10, 10}, options); @@ 
-9731,13 +9832,15 @@ TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder3_CUDA) { for (const auto i : c10::irange(2)) { Fusion fusion; FusionGuard fg(&fusion); @@ -9745,11 +9848,11 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv1, tv3); @@ -9774,9 +9877,6 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { compute_at_outer->computeAt(tv5, -2); compute_at_inner->computeAt(tv5, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); auto t1 = aten_input + 1; @@ -9792,6 +9892,8 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -9799,25 +9901,25 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { } } -TEST(NVFuserTest, FusionTraversalOrder4_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // First tree TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv1, new Double(3)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); fusion.addOutput(tv2); fusion.addOutput(tv3); // Second tree TensorView* tv4 = makeSymbolicTensor(1); fusion.addInput(tv4); - TensorView* tv5 = add(tv4, new Double(5)); - TensorView* tv6 = add(tv5, new Double(6)); - TensorView* tv7 = add(tv5, new Double(7)); + TensorView* tv5 = add(tv4, IrBuilder::create(5)); + TensorView* tv6 = add(tv5, IrBuilder::create(6)); + TensorView* tv7 = add(tv5, IrBuilder::create(7)); fusion.addOutput(tv6); fusion.addOutput(tv7); @@ -9844,23 +9946,23 @@ TEST(NVFuserTest, FusionTraversalOrder4_CUDA) { at::empty_like(t0, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, cg_outputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = 
add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv2, tv4); fusion.addOutput(tv1); @@ -9870,9 +9972,6 @@ TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { tv2->computeAt(tv5, -1); tv4->computeAt(tv5, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); std::vector cg_outputs = { @@ -9880,6 +9979,8 @@ TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); auto t1 = aten_input + 1; @@ -9894,16 +9995,16 @@ TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv0, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(2)); TensorView* tv3 = add(tv1, tv2); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); fusion.addOutput(tv4); @@ -9916,9 +10017,6 @@ TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { tv1->computeAt(tv3, -1); tv2->computeAt(tv3, -2); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); @@ -9929,22 +10027,24 @@ TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { at::Tensor cg_output = at::empty_like(aten_input, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv2, tv4); fusion.addOutput(tv5); @@ -9963,9 +10063,6 @@ TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { tv2->computeAt(tv5, -4); tv4->computeAt(tv5, -3); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); @@ -9976,6 +10073,9 @@ TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { auto aten_output = t2 + t4; at::Tensor cg_output = at::empty_like(aten_input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( @@ -9983,7 +10083,7 @@ TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { } // Test predication of grid reduction -TEST(NVFuserTest, FusionThreadPredicate_CUDA) { +TEST_F(NVFuserTest, FusionThreadPredicate_CUDA) { const int gdimx = 4; const int bdimx = 128; @@ -9993,9 
+10093,10 @@ TEST(NVFuserTest, FusionThreadPredicate_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1); - TensorView* tv3 = add(tv0, new Double(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(2)); fusion.addOutput(tv3); fusion.addOutput(tv2); @@ -10036,14 +10137,14 @@ TEST(NVFuserTest, FusionThreadPredicate_CUDA) { at::empty_like(aten_input, options), at::empty({numel_x}, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionLSTMCell_CUDA) { +TEST_F(NVFuserTest, FusionLSTMCell_CUDA) { const int hidden_features = 512; const int batch_size = 64; @@ -10116,14 +10217,14 @@ TEST(NVFuserTest, FusionLSTMCell_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10131,7 +10232,7 @@ TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); TensorView* tv2 = broadcast(tv1, {true, false}); TensorView* tv3 = broadcast(tv1, {false, true}); TensorView* tv4 = add(tv2, tv3); @@ -10142,7 +10243,7 @@ TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { ASSERT_ANY_THROW(tv1->computeAt(tv3, -1)); } -TEST(NVFuserTest, FusionReductionHalf_CUDA) { +TEST_F(NVFuserTest, FusionReductionHalf_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10151,7 +10252,7 @@ TEST(NVFuserTest, FusionReductionHalf_CUDA) { fusion.addInput(tv0); auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = add(tv1, new Double(1.0)); + auto tv2 = add(tv1, IrBuilder::create(1.0)); auto tv3 = sum(tv2, {2}); auto tv4 = castOp(DataType::Half, tv3); @@ -10172,7 +10273,7 @@ TEST(NVFuserTest, FusionReductionHalf_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); @@ -10189,7 +10290,7 @@ TEST(NVFuserTest, FusionReductionHalf_CUDA) { lparams); } -TEST(NVFuserTest, FusionReduceSingle_CUDA) { +TEST_F(NVFuserTest, FusionReduceSingle_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10205,7 +10306,7 @@ TEST(NVFuserTest, FusionReduceSingle_CUDA) { // Grab only tensor views, though there shouldn't be any other type FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}); @@ -10214,7 +10315,7 @@ TEST(NVFuserTest, FusionReduceSingle_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { +TEST_F(NVFuserTest, 
FusionReduceImplicitBroadcast_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -10226,8 +10327,8 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim, 2}, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim, 2}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -10241,7 +10342,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({red_dim, 2}); @@ -10257,7 +10358,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { lparams); } -TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -10269,10 +10370,11 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); - TensorView* tv2 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv1); + TensorView* tv2 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv1); fusion.addOutput(tv2); const auto options = @@ -10287,7 +10389,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({1, 2}); @@ -10303,7 +10405,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { lparams); } -TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -10315,10 +10417,11 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); - TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1); + TensorView* tv2 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); fusion.addOutput(tv2); const auto options = @@ -10332,7 +10435,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({2, 1}); @@ -10348,24 +10451,27 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { lparams); } -TEST(NVFuserTest, FusionTrivialReduction_CUDA) 
{
+TEST_F(NVFuserTest, FusionTrivialReduction_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
   // Set up your input tensor views
   TensorView* tv0 = makeConcreteTensor({10, 20, 1});
   fusion.addInput(tv0);
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0);
+  TensorView* tv1 =
+      reductionOp(BinaryOpType::Add, {2}, IrBuilder::create<Double>(0), tv0);
   fusion.addOutput(tv1);
 
-  TORCH_CHECK(!fusion.hasReduction(), "Trivial reduction picked up by fusion");
+  TORCH_CHECK(
+      ir_utils::getReductionOps(&fusion).empty(),
+      "Trivial reduction picked up by fusion");
 
   const auto options =
       at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   at::Tensor aten_input = at::randn({10, 20, 1}, options);
 
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, {aten_input});
   auto cg_outputs = fe.runFusion({aten_input});
   auto aten_output = aten_input.to(at::kDouble).sum({2});
@@ -10373,7 +10479,7 @@ TEST(NVFuserTest, FusionTrivialReduction_CUDA) {
       &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
 }
 
-TEST(NVFuserTest, FusionTrivialReduction2_CUDA) {
+TEST_F(NVFuserTest, FusionTrivialReduction2_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -10400,14 +10506,14 @@ TEST(NVFuserTest, FusionTrivialReduction2_CUDA) {
   auto lparams = schedulePointwise(&fusion, aten_inputs);
 
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, aten_inputs, lparams);
   auto cg_outputs = fe.runFusion(aten_inputs, lparams);
 
   testValidate(
       &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
 }
 
-TEST(NVFuserTest, FusionTrivialReduction3_CUDA) {
+TEST_F(NVFuserTest, FusionTrivialReduction3_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -10433,7 +10539,7 @@ TEST(NVFuserTest, FusionTrivialReduction3_CUDA) {
   auto lparams = schedulePointwise(&fusion, aten_inputs);
 
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, aten_inputs, lparams);
   auto cg_outputs = fe.runFusion(aten_inputs, lparams);
 
   testValidate(
@@ -10442,7 +10548,7 @@ TEST(NVFuserTest, FusionTrivialReduction3_CUDA) {
 
 // Make sure trivial reductions are correctly detected even with
 // scheduling applied.
-TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
+TEST_F(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -10459,8 +10565,8 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
   auto tv4 = tv2->rFactor({-1});
 
   auto tv5 = broadcast(tv0, {true, false});
-  auto tv6 = add(tv5, new Double(1));
-  auto tv7 = sub(tv6, new Double(1));
+  auto tv6 = add(tv5, IrBuilder::create<Double>(1));
+  auto tv7 = sub(tv6, IrBuilder::create<Double>(1));
   auto tv8 = sum(tv7, {0});
   fusion.addOutput(tv8);
 
@@ -10483,10 +10589,10 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
 
   GpuLower gpulw(&fusion);
 
-  // No kir::ReductionOp should be generated as all the reduction
+  // No ReductionOp should be generated as all the reduction
   // exprs should be replaced with a unary set op.
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    TORCH_CHECK(!kir_node->isA<kir::ReductionOp>());
+  for (const auto expr : gpulw.kernel()->as<Fusion>()->exprs()) {
+    TORCH_CHECK(!expr->isA<ReductionOp>());
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
@@ -10494,7 +10600,7 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
   std::vector<IValue> aten_inputs = {t0};
 
   FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  fe.compileFusion(&fusion, aten_inputs);
   auto cg_outputs = fe.runFusion(aten_inputs);
 
   testValidate(
@@ -10502,14 +10608,14 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
 }
 
 // Test detection of partially trivial reduction
-TEST(NVFuserTest, FusionDetectTrivialReduction2_CUDA) {
+TEST_F(NVFuserTest, FusionDetectTrivialReduction2_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
   auto tv0 = makeSymbolicTensor(2);
   fusion.addInput(tv0);
   auto tv1 = sum(tv0, {1});
-  auto tv2 = add(tv1, new Double(1));
+  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
   fusion.addOutput(tv2);
 
   tv1->split(1, 1);
@@ -10525,17 +10631,17 @@ TEST(NVFuserTest, FusionDetectTrivialReduction2_CUDA) {
   GpuLower gpulw(&fusion);
 
   // tv3's reduction axis is a trivial reduction. The only
-  // kir::ReductionOp should be for tv1.
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (kir_node->isA<kir::ReductionOp>()) {
+  // ReductionOp should be for tv1.
+  for (const auto expr : gpulw.kernel()->as<Fusion>()->exprs()) {
+    if (expr->isA<ReductionOp>()) {
       auto reduction_out =
-          kir_node->as<kir::ReductionOp>()->outputs()[0]->as<kir::TensorView>();
-      TORCH_CHECK(reduction_out->fuserTv() == tv1);
+          expr->as<ReductionOp>()->outputs()[0]->as<TensorView>();
+      TORCH_CHECK(reduction_out->name() == 1);
     }
   }
 }
 
-TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
+TEST_F(NVFuserTest, FusionInputsIdLookup_CUDA) {
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   at::Tensor t0 = at::randn({16, 8, 8}, options);
   at::Tensor t1 = at::randn({8, 8}, options);
@@ -10573,7 +10679,7 @@ TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
   TORCH_CHECK(id_1_relook.eviction == false);
 }
 
-TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
+TEST_F(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
   std::vector<int64_t> sizes_vec({16, 8, 8});
   std::vector<int64_t> strides_vec({64, 8, 1});
   auto tensor_type = TensorType::create(
@@ -10610,7 +10716,7 @@ TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
   TORCH_CHECK(complyWith(t6, TensorType::create(t6)));
 }
 
-TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
+TEST_F(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
   std::vector<int64_t> sizes_vec({16, 1, 8});
   std::vector<int64_t> strides_vec({8, 8, 1});
   auto tensor_type = TensorType::create(
@@ -10634,7 +10740,7 @@ TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
   TORCH_CHECK(complyWith(t3, tensor_type));
 }
 
-TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
+TEST_F(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
   std::vector<int64_t> sizes_vec({16, 8, 8});
   std::vector<int64_t> strides_vec({64, 1, 8});
   auto tensor_type = TensorType::create(
@@ -10650,7 +10756,7 @@ TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
   TORCH_CHECK(complyWith(t1, tensor_type));
 }
 
-TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
+TEST_F(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
   std::vector<int64_t> sizes_vec({16, 8, 8});
   std::vector<int64_t> strides_vec({128, 16, 1});
   auto tensor_type = TensorType::create(
@@ -10666,7 +10772,7 @@ TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
   TORCH_CHECK(complyWith(t1, tensor_type));
 }
 
-TEST(NVFuserTest, FusionDisjointSet_CUDA) {
+TEST_F(NVFuserTest, FusionDisjointSet_CUDA) {
   DisjointSet<int> set;
 
   const std::set<int> group_x({0, 1, 2});
@@
-10779,7 +10885,7 @@ TEST(NVFuserTest, FusionDisjointSet_CUDA) { } } -TEST(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { +TEST_F(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10802,7 +10908,7 @@ TEST(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { ASSERT_ANY_THROW(tv3->computeAt(tv4, -1)); } -TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) { +TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10819,14 +10925,14 @@ TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto t3 = castOp(DataType::Float, t2); auto t4 = broadcast(t1, {true, true, false}); auto t5 = add(t4, t3); - auto t6 = mul(t5, new Double(0.5)); - auto t7 = mul(t5, new Double(k_079)); - auto t8 = mul(t5, new Double(k_004)); + auto t6 = mul(t5, IrBuilder::create(0.5)); + auto t7 = mul(t5, IrBuilder::create(k_079)); + auto t8 = mul(t5, IrBuilder::create(k_004)); auto t9 = mul(t8, t5); - auto t10 = add(t9, new Int(1)); + auto t10 = add(t9, IrBuilder::create(1)); auto t11 = mul(t7, t10); auto t12 = unaryOp(UnaryOpType::Tanh, t11); - auto t13 = add(t12, new Double(1)); + auto t13 = add(t12, IrBuilder::create(1)); auto t14 = mul(t6, t13); auto t15 = castOp(DataType::Half, t14); fusion.addOutput(t15); @@ -10849,15 +10955,14 @@ TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { +TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { if (at::cuda::getDeviceProperties(0)->major < 6) { return; } @@ -10882,23 +10987,23 @@ TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { auto t5 = castOp(DataType::Float, t4); auto t6 = broadcast(t3, {true, true, false}); auto t7 = add(t6, t5); - auto t8 = mul(t7, new Double(k_079)); - auto t9 = mul(t7, new Double(k_004)); + auto t8 = mul(t7, IrBuilder::create(k_079)); + auto t9 = mul(t7, IrBuilder::create(k_004)); auto t10 = mul(t9, t7); - auto t11 = add(t10, new Int(1)); + auto t11 = add(t10, IrBuilder::create(1)); auto t12 = mul(t8, t11); auto t13 = unaryOp(UnaryOpType::Tanh, t12); - auto t14 = mul(t7, new Double(0.5)); + auto t14 = mul(t7, IrBuilder::create(0.5)); auto t15 = mul(t13, t13); auto t16 = unaryOp(UnaryOpType::Neg, t15); - auto t17 = add(t16, new Int(1)); - auto t18 = mul(t7, new Double(k_010)); + auto t17 = add(t16, IrBuilder::create(1)); + auto t18 = mul(t7, IrBuilder::create(k_010)); auto t19 = mul(t18, t7); - auto t20 = add(t19, new Double(k_079)); + auto t20 = add(t19, IrBuilder::create(k_079)); auto t21 = mul(t17, t20); auto t22 = mul(t14, t21); - auto t23 = add(t13, new Int(1)); - auto t24 = mul(t23, new Double(0.5)); + auto t23 = add(t13, IrBuilder::create(1)); + auto t24 = mul(t23, IrBuilder::create(0.5)); auto t25 = add(t22, t24); auto t26 = mul(t25, t1); // Save float output for validation @@ -10929,8 +11034,7 @@ TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( @@ -10938,7 +11042,7 @@ TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { } // Reproducer of issue #459 -TEST(NVFuserTest, FusionIssue459_CUDA) { +TEST_F(NVFuserTest, FusionIssue459_CUDA) { Fusion fusion; FusionGuard 
fg(&fusion); @@ -10947,14 +11051,14 @@ TEST(NVFuserTest, FusionIssue459_CUDA) { auto tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = add(tv1, tv3); // Create two outputs from the final arithmetic result - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); - auto tv6 = add(tv4, new Double(1)); + auto tv6 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv6); // Scheduling @@ -10981,8 +11085,7 @@ TEST(NVFuserTest, FusionIssue459_CUDA) { std::vector aten_inputs = {t0, t1}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -10994,15 +11097,15 @@ TEST(NVFuserTest, FusionIssue459_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionSmemIndexingSimple_CUDA) { +TEST_F(NVFuserTest, FusionSmemIndexingSimple_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); tv3->axis(0)->parallelize(ParallelType::BIDx); @@ -11013,28 +11116,27 @@ TEST(NVFuserTest, FusionSmemIndexingSimple_CUDA) { tv1->setMemoryType(MemoryType::Shared); tv2->setMemoryType(MemoryType::Global); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto aten_input = at::randn({12, 34}, options); at::Tensor aten_output = aten_input + 1.0 + 1.0 + 1.0; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSmemIndexing_CUDA) { +TEST_F(NVFuserTest, FusionSmemIndexing_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = new Int(); - Int* symbolic_split_k_tile_dim = new Int(); - Int* symbolic_block_k_tile_dim = new Int(); + Int* symbolic_m_tile_dim = IrBuilder::create(); + Int* symbolic_split_k_tile_dim = IrBuilder::create(); + Int* symbolic_block_k_tile_dim = IrBuilder::create(); // Compile-time integer for tiling int n_smem_tile = 32; @@ -11131,9 +11233,8 @@ TEST(NVFuserTest, FusionSmemIndexing_CUDA) { // A, B, m_tile_dim, split_k, intra_cta_tile std::vector aten_inputs = {t0, t1, 3, 4, 5}; - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -11141,13 +11242,13 @@ TEST(NVFuserTest, FusionSmemIndexing_CUDA) { } // Reproducer of issue 408 -TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) { +TEST_F(NVFuserTest, FusionCacheBeforeReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); fusion.addOutput(tv2); @@ -11160,9 +11261,6 @@ TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int 
numel_x = 100; const int numel_y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -11172,21 +11270,23 @@ TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) { auto aten_output = (aten_input + 1).to(at::kDouble).sum({1}); + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { +TEST_F(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(3); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv2); fusion.addOutput(tv3); @@ -11201,9 +11301,6 @@ TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 10; const int numel_y = 20; const int numel_z = 30; @@ -11214,20 +11311,22 @@ TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { auto t3 = t2 + 1; std::vector aten_outputs = {t2, t3}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue367_CUDA) { +TEST_F(NVFuserTest, FusionIssue367_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = new Int(); - Int* symbolic_split_k_tile_dim = new Int(); - Int* symbolic_block_k_tile_dim = new Int(); + Int* symbolic_m_tile_dim = IrBuilder::create(); + Int* symbolic_split_k_tile_dim = IrBuilder::create(); + Int* symbolic_block_k_tile_dim = IrBuilder::create(); // Compile-time integer for tiling int n_smem_tile = 32; @@ -11320,14 +11419,14 @@ TEST(NVFuserTest, FusionIssue367_CUDA) { mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue468_CUDA) { +TEST_F(NVFuserTest, FusionIssue468_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -11346,15 +11445,15 @@ TEST(NVFuserTest, FusionIssue468_CUDA) { at::Tensor aten_input = at::randn({10, 100}, options); at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}).sum({0}); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue363_CUDA) { +TEST_F(NVFuserTest, FusionIssue363_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -11402,21 +11501,21 @@ TEST(NVFuserTest, FusionIssue363_CUDA) { std::vector aten_inputs = {t0, t1}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue484_CUDA) { 
+TEST_F(NVFuserTest, FusionIssue484_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, new Double(0)); + auto tv2 = add(tv1, IrBuilder::create(0)); fusion.addOutput(tv2); tv1->setMemoryType(MemoryType::Global); @@ -11430,20 +11529,20 @@ TEST(NVFuserTest, FusionIssue484_CUDA) { at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue329_CUDA) { +TEST_F(NVFuserTest, FusionIssue329_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); fusion.addOutput(tv2); auto tv3 = sum(tv1, {1}); @@ -11460,22 +11559,21 @@ TEST(NVFuserTest, FusionIssue329_CUDA) { std::vector aten_outputs = {t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue382_CUDA) { +TEST_F(NVFuserTest, FusionIssue382_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = broadcast(tv1, {false, false, true}); auto tv3 = makeSymbolicTensor(3); fusion.addInput(tv3); @@ -11492,9 +11590,6 @@ TEST(NVFuserTest, FusionIssue382_CUDA) { tv1->setMemoryType(MemoryType::Global); tv2->setMemoryType(MemoryType::Global); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 12; const int numel_y = 34; const int numel_z = 56; @@ -11507,20 +11602,22 @@ TEST(NVFuserTest, FusionIssue382_CUDA) { std::vector aten_inputs = {t0, t3}; auto aten_output = (t0 + 1).unsqueeze(-1) + t3; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue507_CUDA) { +TEST_F(NVFuserTest, FusionIssue507_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->setMemoryType(MemoryType::Shared); @@ -11538,22 +11635,21 @@ TEST(NVFuserTest, FusionIssue507_CUDA) { auto aten_output = (t1 + 1); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue532_CUDA) { +TEST_F(NVFuserTest, FusionIssue532_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Algorithm TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(1)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(1)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -11579,7 +11675,7 @@ 
TEST(NVFuserTest, FusionIssue532_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0 + 1 + 1; @@ -11588,14 +11684,14 @@ TEST(NVFuserTest, FusionIssue532_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionLoopUnswitch_CUDA) { +TEST_F(NVFuserTest, FusionLoopUnswitch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Algorithm TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(1)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(1)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -11612,7 +11708,7 @@ TEST(NVFuserTest, FusionLoopUnswitch_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0 + 1 + 1; @@ -11621,7 +11717,7 @@ TEST(NVFuserTest, FusionLoopUnswitch_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue549_CUDA) { +TEST_F(NVFuserTest, FusionIssue549_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -11631,7 +11727,7 @@ TEST(NVFuserTest, FusionIssue549_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = broadcast(tv2, {false, false, true}); // tv3[I0, I1, B] = tv0[I0, I1] @@ -11689,10 +11785,12 @@ TEST(NVFuserTest, FusionIssue549_CUDA) { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - FusionExecutor fe; - fe.compileFusion(&fusion); // Lets specify a few bounds in launch params to make sure it works - fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); + LaunchParams lparams(1, -1, -1, 32, 4, 4); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + fe.runFusion({t0, t1}, lparams); // Make sure bad launch params throws // TODO: Re-enable once we have parallelization validation in. 
@@ -11707,7 +11805,7 @@ TEST(NVFuserTest, FusionIssue549_CUDA) { &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, simplecompileRtc_CUDA) { +TEST_F(NVFuserTest, FusionSimpleCompileRtc_CUDA) { FusionExecutor fe; std::string kernel = R"( __global__ void kernel1(Tensor T0, Tensor T1) { @@ -11739,7 +11837,7 @@ __global__ void kernel1(Tensor T0, Tensor T1) { TORCH_CHECK(out_ref.allclose(out0)); } -TEST(NVFuserTest, FusionSerialWelford_CUDA) { +TEST_F(NVFuserTest, FusionSerialWelford_CUDA) { FusionExecutor fe; int x = 128, y = 64, z = 64; @@ -11796,7 +11894,7 @@ __global__ void kernel1( TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); } -TEST(NVFuserTest, FusionBlockWelford_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelford_CUDA) { FusionExecutor fe; int x = 7, y = 8, z = 9; @@ -11884,7 +11982,7 @@ __global__ void kernel1( cat_tensor.mean({1}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); } -TEST(NVFuserTest, FusionBlockWelfordNoInit_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelfordNoInit_CUDA) { FusionExecutor fe; int x = 7, y = 8, z = 9; @@ -11950,7 +12048,7 @@ __global__ void kernel1( TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); } -TEST(NVFuserTest, FusionGridWelfordNoInit_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordNoInit_CUDA) { FusionExecutor fe; int x = 128, y = 64, z = 128; @@ -12040,7 +12138,7 @@ __global__ void kernel1( TORCH_CHECK(in0.var(dims, false).allclose(out_var)); } -TEST(NVFuserTest, FusionWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12048,7 +12146,7 @@ TEST(NVFuserTest, FusionWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12069,7 +12167,7 @@ TEST(NVFuserTest, FusionWelfordOp_CUDA) { at::Tensor t0 = at::randn({M, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12084,7 +12182,7 @@ TEST(NVFuserTest, FusionWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12092,7 +12190,7 @@ TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12115,7 +12213,7 @@ TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { at::Tensor t_N = at::empty({M}, options_int); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12130,7 +12228,7 @@ TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12138,7 +12236,7 @@ TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ 
-12161,7 +12259,7 @@ TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { at::Tensor t_N = at::empty({M}, options_int); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12176,7 +12274,7 @@ TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionRfactorWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12184,7 +12282,7 @@ TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12206,7 +12304,7 @@ TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { at::Tensor t_N = at::empty({M}, options_int); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12221,7 +12319,7 @@ TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionWelfordSchedule_CUDA) { +TEST_F(NVFuserTest, FusionWelfordSchedule_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12229,7 +12327,7 @@ TEST(NVFuserTest, FusionWelfordSchedule_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12246,9 +12344,10 @@ TEST(NVFuserTest, FusionWelfordSchedule_CUDA) { auto reduction_params = getReductionHeuristics(&fusion, {t0}); scheduleReduction(&fusion, reduction_params.value()); + auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0}, reduction_params.value().lparams); + fe.compileFusion(&fusion, {t0}, lparams); + auto outputs = fe.runFusion({t0}, lparams); // by default Welford outputs sum of square diff so need to divide to get var outputs[1] /= N; @@ -12283,7 +12382,7 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) { tv0_cast = castOp(DataType::Float, tv0); } fusion.addInput(tv0); - auto tv1 = mul(tv0_cast, new Double(1)); + auto tv1 = mul(tv0_cast, IrBuilder::create(1)); auto tvs = Welford(tv1, {axis}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12324,8 +12423,8 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({aten_input}, reduction_params.value().lparams); + fe.compileFusion(&fusion, {aten_input}, lparams); + auto outputs = fe.runFusion({aten_input}, lparams); // by default Welford outputs sum of square diff so need to divide to // get var @@ -12351,7 +12450,7 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) { } } // namespace -TEST(NVFuserTest, FusionWelfordShmoo_CUDA) { +TEST_F(NVFuserTest, FusionWelfordShmoo_CUDA) { std::vector dtypes = { DataType::Double, DataType::Float, DataType::Half}; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 @@ -12393,7 +12492,7 @@ TEST(NVFuserTest, FusionWelfordShmoo_CUDA) { } } -TEST(NVFuserTest, FusionTranspose1_CUDA) { +TEST_F(NVFuserTest, FusionTranspose1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12414,7 +12513,7 @@ 
TEST(NVFuserTest, FusionTranspose1_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0.t(); @@ -12423,7 +12522,7 @@ TEST(NVFuserTest, FusionTranspose1_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTranspose2_CUDA) { +TEST_F(NVFuserTest, FusionTranspose2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12447,7 +12546,7 @@ TEST(NVFuserTest, FusionTranspose2_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0.t(); @@ -12456,7 +12555,7 @@ TEST(NVFuserTest, FusionTranspose2_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { +TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12526,10 +12625,11 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { at::Tensor t0 = at::randn({K, M}, options); at::Tensor t1 = at::randn({N, K}, options); - FusionExecutor fe; - fe.compileFusion(&fusion); // Lets specify a few bounds in launch params to make sure it works - fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); + LaunchParams lparams(1, -1, -1, 32, 4, 4); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + fe.runFusion({t0, t1}, lparams); // Don't specify any launch params auto cg_outputs = fe.runFusion({t0, t1}); @@ -12540,7 +12640,7 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12593,7 +12693,7 @@ TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_input_t = at::transpose(input, 1, 2); @@ -12603,7 +12703,7 @@ TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -12620,10 +12720,10 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { tv0 = transpose(tv0, {{0, 1}}); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = add(tv1, new Double(3.0)); - TensorView* tv4 = mul(tv1, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); TensorView* tv5 = add(tv3, tv2); TensorView* tv6 = add(tv5, tv4); @@ -12654,7 +12754,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -12667,7 +12767,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { at::Tensor 
aten_input = at::randn({129, 127}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); at::Tensor aten_input_t = aten_input.t(); @@ -12686,7 +12786,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -12702,9 +12802,9 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { tv0 = transpose(tv0, {{0, 1}}); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); @@ -12723,7 +12823,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { tv0->computeAt(tv6, 1); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -12736,7 +12836,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { at::Tensor input = at::randn({129, 127}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto input_t = input.t(); @@ -12752,7 +12852,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 @@ -12769,7 +12869,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); - TensorView* tv2 = mul(tv1, new Double(.979361)); + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); fusion.addOutput(tv3); @@ -12786,7 +12886,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { tv3->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -12802,7 +12902,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t0_t = t0.permute({3, 0, 1, 2}); @@ -12814,7 +12914,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 @@ -12862,7 +12962,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { tv6->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -12880,7 +12980,7 @@ TEST(NVFuserTest, 
FusionAdvancedComputeAtTransposed4_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t0_t = t0.permute({3, 0, 1, 2}); @@ -12895,7 +12995,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 @@ -12909,7 +13009,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); tv1 = transpose(tv1, {{0, 1}}); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -12928,7 +13028,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t2 = t0.t().add(2.0); @@ -12938,7 +13038,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12948,7 +13048,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); tv1 = transpose(tv1, {{0, 1}}); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -12970,7 +13070,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t2 = t0.t().add(2.0); @@ -12980,7 +13080,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) { +TEST_F(NVFuserTest, FusionSegmentReducePointwise_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -12992,7 +13092,7 @@ TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) { fusion->addInput(tv1); fusion->addInput(tv2); - TensorView* tv3 = add(tv0, new Double(1)); // Group 0 + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 TensorView* tv4 = max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, @@ -13029,7 +13129,7 @@ TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) { executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultipleVectorize_CUDA) { +TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -13089,7 +13189,7 @@ TEST(NVFuserTest, FusionMultipleVectorize_CUDA) { TORCH_CHECK(runtime1 != runtime3); } -TEST(NVFuserTest, FusionVectorizeSimple_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeSimple_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13123,7 +13223,7 @@ TEST(NVFuserTest, FusionVectorizeSimple_CUDA) { at::Tensor 
aten_input = at::empty({2, 6, 32}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); at::Tensor aten_output = aten_input.sin(); @@ -13132,7 +13232,7 @@ TEST(NVFuserTest, FusionVectorizeSimple_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { +TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // dimensionality of the problem @@ -13148,7 +13248,7 @@ TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -13197,7 +13297,7 @@ TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { at::Tensor output = at::empty_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; @@ -13206,7 +13306,7 @@ TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { +TEST_F(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -13220,7 +13320,7 @@ TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { fusion->addInput(tv0); - auto tv1 = add(tv0, new Double(1.0)); + auto tv1 = add(tv0, IrBuilder::create(1.0)); auto tv2 = sum(tv1, {2}); // Group 0 auto output = softmax(tv2, kReductionAxis); // Group 1 @@ -13247,14 +13347,14 @@ TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSwizzle1_CUDA) { +TEST_F(NVFuserTest, FusionSwizzle1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = mul(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = mul(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); tv2->split(0, 7); @@ -13279,7 +13379,7 @@ TEST(NVFuserTest, FusionSwizzle1_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = (t0 + 1) * 2; @@ -13288,14 +13388,14 @@ TEST(NVFuserTest, FusionSwizzle1_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSwizzle2_CUDA) { +TEST_F(NVFuserTest, FusionSwizzle2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = mul(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = mul(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -13323,7 +13423,7 @@ TEST(NVFuserTest, FusionSwizzle2_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = (t0 + 1) * 2; @@ -13332,7 +13432,7 @@ TEST(NVFuserTest, FusionSwizzle2_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { +TEST_F(NVFuserTest, 
FusionTransposeWithSwizzle_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13385,8 +13485,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.t(); @@ -13395,7 +13494,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { +TEST_F(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13452,8 +13551,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.t(); @@ -13462,10 +13560,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridPersistence_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridPersistence_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13490,7 +13585,7 @@ TEST(NVFuserTest, FusionGridPersistence_CUDA) { at::Tensor input = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = input.sum({0}).unsqueeze(-1).add(input); @@ -13498,10 +13593,7 @@ TEST(NVFuserTest, FusionGridPersistence_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridPersistence2_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridPersistence2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13528,7 +13620,7 @@ TEST(NVFuserTest, FusionGridPersistence2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = input.sum({0}).unsqueeze(0).add(input); @@ -13536,10 +13628,7 @@ TEST(NVFuserTest, FusionGridPersistence2_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWelfordPersistence_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionWelfordPersistence_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13567,7 +13656,7 @@ TEST(NVFuserTest, FusionWelfordPersistence_CUDA) { at::Tensor input = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) @@ -13577,10 +13666,7 @@ TEST(NVFuserTest, FusionWelfordPersistence_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWelfordPersistence2_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionWelfordPersistence2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13610,7 +13696,7 @@ TEST(NVFuserTest, FusionWelfordPersistence2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, 
{input}); auto out = fe.runFusion({input}); auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) @@ -13620,7 +13706,7 @@ TEST(NVFuserTest, FusionWelfordPersistence2_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue633_CUDA) { +TEST_F(NVFuserTest, FusionIssue633_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13642,14 +13728,13 @@ TEST(NVFuserTest, FusionIssue633_CUDA) { tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dx, dy, dz}, options); at::Tensor t1 = at::randn({dx, dy, 1}, options); std::vector aten_inputs = {t0, t1}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13658,48 +13743,7 @@ TEST(NVFuserTest, FusionIssue633_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionKirScoping_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); - fusion.addOutput(tv2); - - tv2->merge(0); - tv2->split(0, 4); - tv0->computeAt(tv2, -1); - - GpuLower gpulw(&fusion); - - auto kir_tv1 = gpulw.lowerValue(tv1); - auto tv1_scope = kir_tv1->definition()->scope(); - TORCH_CHECK(tv1_scope != nullptr); - TORCH_CHECK(tv1_scope->owner()->as()); - - auto kir_tv2 = gpulw.lowerValue(tv2); - auto tv2_scope = kir_tv2->definition()->scope(); - TORCH_CHECK(tv2_scope != nullptr); - TORCH_CHECK(tv2_scope->owner()->as()); - - TORCH_CHECK(tv1_scope != tv2_scope); - - // tv1 and tv2 should have the same inner-most ForLoop - auto parent_scope = tv1_scope->owner()->scope(); - TORCH_CHECK(parent_scope == tv2_scope->owner()->scope()); - TORCH_CHECK(parent_scope->owner()->as()); - // There should be one more loop - parent_scope = parent_scope->owner()->scope(); - TORCH_CHECK(parent_scope->owner()->as()); - - // scope() should return nullptr for top-level exprs - auto top_level_scope = parent_scope->owner()->scope(); - TORCH_CHECK(top_level_scope == nullptr); -} - -TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { +TEST_F(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13726,7 +13770,7 @@ TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t3 = t0.unsqueeze(-1).expand(shape) + t1; @@ -13734,7 +13778,7 @@ TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13777,7 +13821,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13785,7 +13829,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } 
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13835,7 +13879,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13843,7 +13887,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13896,7 +13940,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13904,7 +13948,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13963,7 +14007,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14013,7 +14057,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.add(t1).sum(1); @@ -14021,7 +14065,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14059,7 +14103,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14100,7 +14144,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -14108,7 +14152,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14151,13 +14195,13 @@ TEST(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, 
aten_inputs); // Failure because the input + output tensors do not have the same stride ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); } -TEST(NVFuserTest, FusionViewOutput_CUDA) { +TEST_F(NVFuserTest, FusionViewOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14181,7 +14225,7 @@ TEST(NVFuserTest, FusionViewOutput_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_add_bias = at_x + at_bias; @@ -14190,7 +14234,7 @@ TEST(NVFuserTest, FusionViewOutput_CUDA) { testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionViewFailMismatchSize_CUDA) { +TEST_F(NVFuserTest, FusionViewFailMismatchSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14210,7 +14254,7 @@ TEST(NVFuserTest, FusionViewFailMismatchSize_CUDA) { ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape)); } -TEST(NVFuserTest, FusionViewFailMulitDimInference_CUDA) { +TEST_F(NVFuserTest, FusionViewFailMulitDimInference_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14228,7 +14272,7 @@ TEST(NVFuserTest, FusionViewFailMulitDimInference_CUDA) { ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape)); } -TEST(NVFuserTest, FusionViewFailReduction_CUDA) { +TEST_F(NVFuserTest, FusionViewFailReduction_CUDA) { std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -14259,7 +14303,7 @@ TEST(NVFuserTest, FusionViewFailReduction_CUDA) { ASSERT_ANY_THROW(fusion_executor_cache.runFusionWithInputs({at_x, at_bias})); } -TEST(NVFuserTest, FusionViewFailPersistent_CUDA) { +TEST_F(NVFuserTest, FusionViewFailPersistent_CUDA) { std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -14319,7 +14363,7 @@ void addViewGeluFusion( auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_add_bias = at_x + at_bias; @@ -14330,25 +14374,25 @@ void addViewGeluFusion( } } -TEST(NVFuserTest, FusionViewSplit_CUDA) { +TEST_F(NVFuserTest, FusionViewSplit_CUDA) { std::vector<int64_t> input_shape{80}; std::vector<int64_t> output_shape{2, 4, 10}; addViewGeluFusion(input_shape, output_shape); } -TEST(NVFuserTest, FusionViewBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionViewBroadcast_CUDA) { std::vector<int64_t> input_shape{80}; std::vector<int64_t> output_shape{1, 80}; addViewGeluFusion(input_shape, output_shape); } -TEST(NVFuserTest, FusionViewMerge_CUDA) { +TEST_F(NVFuserTest, FusionViewMerge_CUDA) { std::vector<int64_t> input_shape{2, 40, 7}; std::vector<int64_t> output_shape{560}; addViewGeluFusion(input_shape, output_shape); } -TEST(NVFuserTest, FusionViewAllShmoo_CUDA) { +TEST_F(NVFuserTest, FusionViewAllShmoo_CUDA) { typedef std::vector<int64_t> shape; typedef std::pair<shape, shape> view_example; @@ -14373,7 +14417,7 @@ TEST(NVFuserTest, FusionViewAllShmoo_CUDA) { } } -TEST(NVFuserTest, FusionViewInferShmoo_CUDA) { +TEST_F(NVFuserTest, FusionViewInferShmoo_CUDA) { typedef std::vector<int64_t> shape; typedef std::pair<shape, shape> view_example; @@ -14427,7 +14471,7 @@ void geluViewAddFusion( auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_gelu = at::gelu(at_x); @@ -14438,7 +14482,7 @@ 
void geluViewAddFusion( } } -TEST(NVFuserTest, FusionViewStride_CUDA) { +TEST_F(NVFuserTest, FusionViewStride_CUDA) { typedef std::vector shape; typedef std::pair view_example; @@ -14483,7 +14527,7 @@ void geluViewBinaryAddFusion( auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_gelu = at::gelu(at_x); @@ -14495,11 +14539,11 @@ void geluViewBinaryAddFusion( } } -TEST(NVFuserTest, FusionViewBinary_CUDA) { +TEST_F(NVFuserTest, FusionViewBinary_CUDA) { geluViewBinaryAddFusion({27454, 2}, {54908}, {7844, 7}); } -TEST(NVFuserTest, FusionVectorization1_CUDA) { +TEST_F(NVFuserTest, FusionVectorization1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14540,7 +14584,7 @@ TEST(NVFuserTest, FusionVectorization1_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -14548,7 +14592,7 @@ TEST(NVFuserTest, FusionVectorization1_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorization2_CUDA) { +TEST_F(NVFuserTest, FusionVectorization2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14586,7 +14630,7 @@ TEST(NVFuserTest, FusionVectorization2_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionVectorization3_CUDA) { +TEST_F(NVFuserTest, FusionVectorization3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14623,11 +14667,10 @@ TEST(NVFuserTest, FusionVectorization3_CUDA) { const int by = 2049; at::Tensor t0 = at::randn({bx, by}, options); at::Tensor t1 = at::randn({bx, by}, options); + std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); - - std::vector aten_inputs = {t0, t1}; + fe.compileFusion(&fusion, aten_inputs); ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); aten_inputs[0] = t0.index({"...", Slice(1)}); @@ -14644,7 +14687,7 @@ TEST(NVFuserTest, FusionVectorization3_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) { +TEST_F(NVFuserTest, FusionVectorizationRFactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14692,7 +14735,7 @@ TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.add(t1).sum(1); @@ -14705,7 +14748,7 @@ TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) { } // Unswitched loops with extent one may omit else clause. -TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { +TEST_F(NVFuserTest, FusionSizeOneLoop1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14742,16 +14785,7 @@ TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { // Make sure the unswitched loop does not have an else clause. 
GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto fl = dynamic_cast<kir::ForLoop*>(kir_node.get())) { - if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) { - continue; - } - if (auto pred = dynamic_cast<kir::IfThenElse*>(fl->parentScope())) { - TORCH_CHECK(!pred->hasElse()); - } - } - } + TORCH_CHECK(!UnswitchInElseChecker::check(gpulw)); const int x = 11; const int y = 12; @@ -14763,7 +14797,7 @@ TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { std::vector<IValue> aten_inputs = {t0, t1, t2}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t6 = (t0.unsqueeze(-1) + t1).unsqueeze(0) + t2; @@ -14772,7 +14806,7 @@ TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { // The unswitched loop has extent one but inner loops don't. The else // part should not be omitted. -TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) { +TEST_F(NVFuserTest, FusionSizeOneLoop2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14780,7 +14814,7 @@ TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) { auto tv0 = makeConcreteTensor({x}); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create<Double>(1)); fusion.addOutput(tv1); tv1->split(-1, 4); @@ -14790,38 +14824,29 @@ TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) { // Make sure the size-one unswitched loop does not omit the else clause. GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto fl = dynamic_cast<kir::ForLoop*>(kir_node.get())) { - if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) { - continue; - } - if (auto pred = dynamic_cast<kir::IfThenElse*>(fl->parentScope())) { - TORCH_CHECK(pred->hasElse()); - } - } - } + TORCH_CHECK(UnswitchInElseChecker::check(gpulw)); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x}, options); std::vector<IValue> aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t1 = t0 + 1; testValidate(&fusion, cg_outputs, aten_inputs, {t1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionValidateParallelize1_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create<Double>(1)); + auto tv2 = add(tv1, IrBuilder::create<Double>(1)); fusion.addOutput(tv2); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -14832,15 +14857,15 @@ TEST(NVFuserTest, FusionValidateParallelize1_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionValidateParallelize2_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create<Double>(1)); + auto tv2 = add(tv1, IrBuilder::create<Double>(1)); fusion.addOutput(tv2); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -14853,15 +14878,15 @@ TEST(NVFuserTest, FusionValidateParallelize2_CUDA) { fe.compileFusion(&fusion); } -TEST(NVFuserTest, FusionValidateParallelize3_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto 
tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -14876,15 +14901,15 @@ TEST(NVFuserTest, FusionValidateParallelize3_CUDA) { fe.compileFusion(&fusion); } -TEST(NVFuserTest, FusionValidateParallelize4_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -14899,15 +14924,15 @@ TEST(NVFuserTest, FusionValidateParallelize4_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionValidateParallelize5_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -14924,7 +14949,7 @@ TEST(NVFuserTest, FusionValidateParallelize5_CUDA) { } // See issue #995 -TEST(NVFuserTest, FusionValidateParallelize6_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14933,7 +14958,7 @@ TEST(NVFuserTest, FusionValidateParallelize6_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -14960,7 +14985,7 @@ TEST(NVFuserTest, FusionValidateParallelize6_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -TEST(NVFuserTest, FusionDAGMerging_CUDA) { +TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14976,7 +15001,7 @@ TEST(NVFuserTest, FusionDAGMerging_CUDA) { auto tv5 = sum(tv4, {0}); // 3 // Branch 1 - auto tv6 = add(tv1, new Double(1)); // 4 + auto tv6 = add(tv1, IrBuilder::create(1)); // 4 // Merge auto tv7 = add(tv6, tv5); // 5 @@ -14995,17 +15020,17 @@ TEST(NVFuserTest, FusionDAGMerging_CUDA) { TORCH_CHECK(fusion_segments->groups().size() <= 4); } -TEST(NVFuserTest, FusionDAGScalarMerging_CUDA) { +TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(3); - auto i0 = new Double(); + auto i0 = IrBuilder::create(); fusion->addInput(tv0); fusion->addInput(i0); - auto i1 = add(i0, new Double(1.0)); + auto i1 = add(i0, IrBuilder::create(1.0)); auto i2 = mul(i1, i1); auto i3 = add(i2, i1); @@ -15051,7 +15076,7 @@ TEST(NVFuserTest, FusionDAGScalarMerging_CUDA) { executor_cache.fusion(), outputs, {t0, s0}, {t5}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { +TEST_F(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15073,14 +15098,14 @@ TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0.sum({1, 2}); testValidate( &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { +TEST_F(NVFuserTest, 
FusionBlockWelfordInSerialLoop_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15106,7 +15131,7 @@ TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_avg = t0.mean({1, 2}); at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K; @@ -15115,7 +15140,7 @@ TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { } // See Issue #716 -TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { +TEST_F(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15129,7 +15154,7 @@ TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { std::vector broadcast_mask = {false, true}; auto tv0_bcast = broadcast(tv0, broadcast_mask); - auto path1_bcast = add(tv0_bcast, new Double(1.0)); + auto path1_bcast = add(tv0_bcast, IrBuilder::create(1.0)); auto path1 = sum(path1_bcast, reduction_axes); fusion.addOutput(path1); @@ -15145,7 +15170,7 @@ TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); // inplace op, we are adding t0 to itself auto outputs = fe.runFusion(aten_inputs, {t0}); @@ -15153,7 +15178,7 @@ TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { TORCH_CHECK(outputs[0].allclose(t0_ref.add(1))); } -TEST(NVFuserTest, FusionReductionPredicate_CUDA) { +TEST_F(NVFuserTest, FusionReductionPredicate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15184,7 +15209,7 @@ TEST(NVFuserTest, FusionReductionPredicate_CUDA) { at::Tensor cg_output = at::empty({numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({0}); @@ -15193,7 +15218,7 @@ TEST(NVFuserTest, FusionReductionPredicate_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue728_CUDA) { +TEST_F(NVFuserTest, FusionIssue728_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15204,10 +15229,10 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { auto tv2 = makeSymbolicTensor(1); fusion.addOutput(tv2); - auto tv3 = add(tv0, new Double(1)); + auto tv3 = add(tv0, IrBuilder::create(1)); auto tv4 = add(tv3, tv1); - auto tv5 = add(tv4, new Double(1)); - auto tv6 = add(tv2, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv5); fusion.addOutput(tv6); @@ -15253,7 +15278,7 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { "Only tv3 should be included"); } -TEST(NVFuserTest, FusionIssue757_CUDA) { +TEST_F(NVFuserTest, FusionIssue757_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15281,7 +15306,7 @@ TEST(NVFuserTest, FusionIssue757_CUDA) { std::vector inputs = {t0, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0.sum({1}); @@ -15292,7 +15317,7 @@ TEST(NVFuserTest, FusionIssue757_CUDA) { } // See issue #759 -TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15323,7 +15348,7 @@ TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { std::vector inputs = {t0, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); 
auto outputs = fe.runFusion(inputs); auto t1 = t0.sum({1}); @@ -15333,7 +15358,7 @@ TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { +TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -15367,12 +15392,12 @@ TEST(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { TORCH_CHECK(segmented_fusion->groups().size() == 2); } -TEST(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { +TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(3); - auto i0 = new Double(); + auto i0 = IrBuilder::create(); fusion->addInput(tv0); fusion->addInput(i0); @@ -15407,7 +15432,7 @@ TEST(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { TORCH_CHECK(segmented_fusion->groups().size() == 2); } -TEST(NVFuserTest, FusionSegmentMixReduction_CUDA) { +TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -15446,7 +15471,7 @@ TEST(NVFuserTest, FusionSegmentMixReduction_CUDA) { TORCH_CHECK(segmented_fusion->groups().size() <= 2); } -TEST(NVFuserTest, FusionSBAR_CUDA) { +TEST_F(NVFuserTest, FusionSBAR_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15492,12 +15517,11 @@ TEST(NVFuserTest, FusionSBAR_CUDA) { // outputs std::vector outputs; - auto lparams = schedulePointwise(&fusion, c10::ArrayRef(inputs)); + auto lparams = schedulePointwise(&fusion, inputs); FusionExecutor executor; - executor.compileFusion(&fusion); - - outputs = executor.runFusion(c10::ArrayRef(inputs), lparams); + executor.compileFusion(&fusion, inputs, lparams); + outputs = executor.runFusion(inputs, lparams); auto at_scale = at::mul(at_x, at_weight); auto at_scale_bias = at::add(at_scale, at_bias); @@ -15507,16 +15531,16 @@ TEST(NVFuserTest, FusionSBAR_CUDA) { testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSingleElement_CUDA) { +TEST_F(NVFuserTest, FusionSingleElement_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(0); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(2.5)); + auto tv1 = add(tv0, IrBuilder::create(2.5)); - auto tv2 = add(tv1, new Double(3.5)); + auto tv2 = add(tv1, IrBuilder::create(3.5)); fusion.addOutput(tv2); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -15527,7 +15551,7 @@ TEST(NVFuserTest, FusionSingleElement_CUDA) { auto lparams = schedulePointwise(&fusion, {input}); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}, lparams); fe.runFusion({input}, {cg_output}, lparams); auto aten_output = input.add(2.5).add(3.5); @@ -15536,7 +15560,7 @@ TEST(NVFuserTest, FusionSingleElement_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) { +TEST_F(NVFuserTest, FusionBNBackwardRepro_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15566,12 +15590,12 @@ TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) { makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous. 
fusion.addInput(gt_0); - auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1)); + auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); auto gt_float = castOp(DataType::Float, gt_bool); auto grad_out = mul(grad_out_prev, gt_float); - Val* eps_ptr = new Double(1e-5); + Val* eps_ptr = IrBuilder::create(1e-5); auto grads = batch_norm_backward( input, @@ -15606,7 +15630,7 @@ TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) { } // TODO: We only changed inputs, merge this with the test above. -TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) { +TEST_F(NVFuserTest, FusionBNBackwardRepro2_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15639,12 +15663,12 @@ TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) { auto gt_0 = makeConcreteTensor({-1, -1, 1, 1}); fusion.addInput(gt_0); - auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1)); + auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); auto gt_float = castOp(DataType::Float, gt_bool); auto grad_out = mul(grad_out_prev, gt_float); - Val* eps_ptr = new Double(1e-5); + Val* eps_ptr = IrBuilder::create(1e-5); auto grads = batch_norm_backward( input, @@ -15678,7 +15702,7 @@ TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) { auto outputs = fec.runFusionWithInputs(inputs); } -TEST(NVFuserTest, FusionBNRepro_CUDA) { +TEST_F(NVFuserTest, FusionBNRepro_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15704,8 +15728,8 @@ TEST(NVFuserTest, FusionBNRepro_CUDA) { auto running_var = makeSymbolicTensor(1); fusion.addInput(running_var); - auto momentum_ptr = new Double(kMomentum); - auto eps_ptr = new Double(kEps); + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); auto result = batch_norm( input, @@ -15759,7 +15783,7 @@ TEST(NVFuserTest, FusionBNRepro_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBNRepro2_CUDA) { +TEST_F(NVFuserTest, FusionBNRepro2_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15777,8 +15801,8 @@ TEST(NVFuserTest, FusionBNRepro2_CUDA) { auto input = makeSymbolicTensor(numDims); fusion.addInput(input); - Val* momentum_ptr = new Double(kMomentum); - Val* eps_ptr = new Double(kEps); + Val* momentum_ptr = IrBuilder::create(kMomentum); + Val* eps_ptr = IrBuilder::create(kEps); auto result = batch_norm( input, @@ -15820,7 +15844,7 @@ TEST(NVFuserTest, FusionBNRepro2_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { +TEST_F(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15830,7 +15854,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { auto tv1 = makeConcreteTensor({0}); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(2.5)); + auto tv2 = add(tv0, IrBuilder::create(2.5)); fusion.addOutput(tv2); auto tv3 = makeConcreteTensor({0}); @@ -15846,7 +15870,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { auto lparams = schedulePointwise(&fusion, {input0, input1}); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}); fe.runFusion({input0, input1}, {cg_output2, cg_output3}, lparams); auto aten_output2 = input0.add(2.5); @@ -15861,7 +15885,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { __FILE__); } -TEST(NVFuserTest, 
FusionZeroSizeTensorReduction_CUDA) { +TEST_F(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15891,7 +15915,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}, lparams); auto cg_outputs = fe.runFusion({input0, input1}, lparams); auto aten_output2 = input0.sum({1}); at::Tensor aten_output3 = at::empty({0}, options); @@ -15907,7 +15931,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { lparams); } -TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { +TEST_F(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15938,7 +15962,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}, lparams); auto cg_outputs = fe.runFusion({input0, input1}, lparams); auto aten_output2 = input0.sum({0}).add(input0); at::Tensor aten_output3 = at::empty({0}, options); @@ -15954,7 +15978,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { lparams); } -TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) { +TEST_F(NVFuserTest, FusionSegmentIoAlias_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -15966,7 +15990,7 @@ TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) { fusion->addInput(tv1); fusion->addInput(tv2); - TensorView* tv3 = add(tv0, new Double(1)); // Group 0 + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 TensorView* tv4 = max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, @@ -16008,7 +16032,7 @@ TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) { executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWelford1Output_CUDA) { +TEST_F(NVFuserTest, FusionWelford1Output_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16028,7 +16052,7 @@ TEST(NVFuserTest, FusionWelford1Output_CUDA) { testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { +TEST_F(NVFuserTest, FusionTranslate1Welford_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16037,7 +16061,8 @@ TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { fusion->addInput(tv0); auto tvs = Welford(tv0, {1}); - fusion->addOutput(tvs.var_sum); + auto tv_out = add(tv0, broadcast(tvs.avg, {false, true})); + fusion->addOutput(tv_out); FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto run_test = [&executor_cache, @@ -16047,9 +16072,13 @@ TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { auto outputs = executor_cache.runFusionWithInputs({t0}); // Square sums does not fit well in the testValidate assumptions, // so we just compare the divided output here. 
- outputs[0] /= inner_size; - auto t1 = t0.var({1}, false); - testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate( + fusion, + outputs, + {t0}, + {t0.add(t0.mean({1}).unsqueeze(1))}, + __LINE__, + __FILE__); return executor_cache.getMostRecentKernelRuntime(); }; @@ -16057,21 +16086,25 @@ TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { // Run a translated welford auto runtime1 = run_test(64); // Check it was translated - TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 2); TORCH_CHECK( - runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() == - ScheduleHeuristic::Persistent); + runtime1->fusionSegments()->groups().size() == 1 && + runtime1->fusionSegments()->groups()[0]->exprs().size() > 2); // Run an un-translated welford auto runtime2 = run_test(65536); - // Check it was not translated - TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 1); - TORCH_CHECK( - runtime2->schedulerHeuristics()->singleKernelHeuristics()->heuristc() == - ScheduleHeuristic::Reduction); + + bool found_welford = false; + for (auto group : runtime2->fusionSegments()->groups()) { + for (auto expr : group->exprs()) { + if (expr->isA()) { + found_welford = true; + } + } + } + TORCH_CHECK(found_welford); } -TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { +TEST_F(NVFuserTest, FusionTranslate2Welford_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16080,10 +16113,12 @@ TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { fusion->addInput(tv0); auto tvs1 = Welford(tv0, {1}); - auto tvs2 = Welford(tv0, {1}); + auto tv_out1 = add(tv0, broadcast(tvs1.avg, {false, true})); + fusion->addOutput(tv_out1); - fusion->addOutput(tvs1.var_sum); - fusion->addOutput(tvs2.var_sum); + auto tvs2 = Welford(tv0, {1}); + auto tv_out2 = add(tv0, broadcast(tvs2.avg, {false, true})); + fusion->addOutput(tv_out2); FusionExecutorCache executor_cache(std::move(fusion_ptr)); @@ -16095,10 +16130,8 @@ TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { // Square sums does not fit well in the testValidate assumptions, // so we just compare the divided output here. 
- outputs[0] /= inner_size; - outputs[1] /= inner_size; - auto t1 = t0.var({1}, false); - testValidate(fusion, outputs, {t0}, {t1, t1}, __LINE__, __FILE__); + auto out = t0.add(t0.mean({1}).unsqueeze(1)); + testValidate(fusion, outputs, {t0}, {out, out}, __LINE__, __FILE__); return executor_cache.getMostRecentKernelRuntime(); }; @@ -16106,18 +16139,25 @@ TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { // Run a translated welford auto runtime1 = run_test(64); // Check it was translated - TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 4); TORCH_CHECK( - runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() == - ScheduleHeuristic::Persistent); + runtime1->fusionSegments()->groups().size() == 1 && + runtime1->fusionSegments()->groups()[0]->exprs().size() > 4); // Run an un-translated welford auto runtime2 = run_test(65536); // // Check it was not translated - TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 2); + bool found_welford = false; + for (auto group : runtime2->fusionSegments()->groups()) { + for (auto expr : group->exprs()) { + if (expr->isA()) { + found_welford = true; + } + } + } + TORCH_CHECK(found_welford); } -TEST(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { +TEST_F(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16150,7 +16190,7 @@ TEST(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { TORCH_CHECK(!runtime->isSegmented()); } -TEST(NVFuserTest, FusionWelfordOtherPersistence_CUDA) { +TEST_F(NVFuserTest, FusionWelfordOtherPersistence_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16176,20 +16216,22 @@ TEST(NVFuserTest, FusionWelfordOtherPersistence_CUDA) { at::Tensor t0 = at::randn({128, inner_size}, options); auto outputs = executor_cache.runFusionWithInputs({t0}); - auto t1 = t0.mean({1}).unsqueeze(1) + t0; - auto t2 = t0.sum({1}).unsqueeze(1) + t0; + auto t1 = t0.to(c10::kDouble).mean({1}).unsqueeze(1) + t0; + auto t2 = t0.to(c10::kDouble).sum({1}).unsqueeze(1) + t0; testValidate(fusion, outputs, {t0}, {t2, t1}, __LINE__, __FILE__); return executor_cache.getMostRecentKernelRuntime(); }; for (auto inner_size : {4096, 8192, 32768}) { - auto runtime = run_test(4096); - TORCH_CHECK(!runtime->isSegmented()); + auto runtime = run_test(inner_size); + TORCH_CHECK( + !runtime->isSegmented() || + runtime->fusionSegments()->groups().size() == 1); } } -TEST(NVFuserTest, FusionSegmentIslands_CUDA) { +TEST_F(NVFuserTest, FusionSegmentIslands_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16211,7 +16253,7 @@ TEST(NVFuserTest, FusionSegmentIslands_CUDA) { fusion_executor_cache.runFusionWithInputs({t0, t1}); } -TEST(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16247,7 +16289,7 @@ TEST(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { TORCH_CHECK(tv8->getMaxProducerPosition() == 2); } -TEST(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16267,7 +16309,7 @@ TEST(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { TORCH_CHECK(tv3->getMaxProducerPosition() == 2); } -TEST(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { auto 
fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16286,7 +16328,7 @@ TEST(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { TORCH_CHECK(tv3->getMaxProducerPosition() == 3); } -TEST(NVFuserTest, FusionSimpleWarp_CUDA) { +TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16317,14 +16359,14 @@ TEST(NVFuserTest, FusionSimpleWarp_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleWarpPad_CUDA) { +TEST_F(NVFuserTest, FusionSimpleWarpPad_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16365,13 +16407,13 @@ TEST(NVFuserTest, FusionSimpleWarpPad_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { +TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16409,13 +16451,13 @@ TEST(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { auto at_output = input1.sum({1, 2}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSerialWarpReduction_CUDA) { +TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16450,13 +16492,13 @@ TEST(NVFuserTest, FusionSerialWarpReduction_CUDA) { auto at_output = input1.sum({1, 2}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTrivialWarpReduction_CUDA) { +TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16494,13 +16536,13 @@ TEST(NVFuserTest, FusionTrivialWarpReduction_CUDA) { auto at_output = input1.sum({1, 2, 3}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultipleDimBinding_CUDA) { +TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16548,7 +16590,7 @@ TEST(NVFuserTest, FusionMultipleDimBinding_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1, input2}); auto outputs = fe.runFusion({input1, input2}); testValidate( fusion.get(), @@ -16559,7 +16601,7 @@ TEST(NVFuserTest, FusionMultipleDimBinding_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionPadNoWarpReduce_CUDA) { +TEST_F(NVFuserTest, FusionPadNoWarpReduce_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16588,19 +16630,19 @@ TEST(NVFuserTest, 
FusionPadNoWarpReduce_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { +TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(2); fusion->addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); fusion->addOutput(tv2); @@ -16623,13 +16665,13 @@ TEST(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { auto at_output = (input1 + 1).sum({1}); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { +TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16673,13 +16715,13 @@ TEST(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegfaultReduction_CUDA) { +TEST_F(NVFuserTest, FusionSegfaultReduction_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -16698,7 +16740,7 @@ TEST(NVFuserTest, FusionSegfaultReduction_CUDA) { std::vector at_sum_axes; std::vector outer_reduction_axes; std::vector outer_broadcast_mask(numDims, false); - Val* N = new Double(1); + Val* N = IrBuilder::create(1); for (const auto axis : c10::irange(numDims)) { if (axis != 1) { outer_reduction_axes.push_back(axis); @@ -16728,16 +16770,16 @@ TEST(NVFuserTest, FusionSegfaultReduction_CUDA) { &fusion, outputs, inputs, {at_output0, at_output1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPredicateElimination_CUDA) { +TEST_F(NVFuserTest, FusionPredicateElimination_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); - auto tv3 = add(tv2, new Double(3)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); + auto tv3 = add(tv2, IrBuilder::create(3)); fusion.addOutput(tv3); @@ -16748,7 +16790,7 @@ TEST(NVFuserTest, FusionPredicateElimination_CUDA) { { GpuLower gpulw(&fusion); - TORCH_CHECK(!isPredicated(tv2, gpulw)); + TORCH_CHECK(!PredicatedChecker::isPredicated(tv2, gpulw)); } tv2->axis(1)->parallelize(ParallelType::Serial); @@ -16756,11 +16798,11 @@ TEST(NVFuserTest, FusionPredicateElimination_CUDA) { { GpuLower gpulw(&fusion); - TORCH_CHECK(isPredicated(tv2, gpulw)); + TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); } } -TEST(NVFuserTest, FusionForceFp16Simple_CUDA) { +TEST_F(NVFuserTest, FusionForceFp16Simple_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16798,53 +16840,55 @@ TEST(NVFuserTest, FusionForceFp16Simple_CUDA) { } } -TEST(NVFuserTest, FusionForceBf16Simple_CUDA) { +TEST_F(NVFuserTest, 
FusionForceBf16Simple_CUDA) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } + + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); - fusion->addInput(tv0); - fusion->addInput(tv1); + // Group 1 + auto tv2 = sum(tv0, {1}); + auto tv3 = broadcast(tv2, {false, true}); - // Group 1 - auto tv2 = sum(tv0, {1}); - auto tv3 = broadcast(tv2, {false, true}); + // Group 2 + auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast + auto tv5 = castOp(DataType::BFloat16, tv4); - // Group 2 - auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast - auto tv5 = castOp(DataType::BFloat16, tv4); + fusion->addOutput(tv5); - fusion->addOutput(tv5); + FusionExecutorCache fec(std::move(fusion_ptr)); - FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector shape{15, 16}; - std::vector shape{15, 16}; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - // Check the segmented edge is bf16 - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - TORCH_CHECK(edge_tv->getDataType() == DataType::BFloat16); - } - } else { - GTEST_SKIP(); + // Check the segmented edge is bf16 + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + TORCH_CHECK(edge_tv->getDataType() == DataType::BFloat16); } #else - GTEST_SKIP(); + GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; #endif } -TEST(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { +TEST_F(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16893,64 +16937,66 @@ TEST(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { } } -TEST(NVFuserTest, FusionForceBf16NotAllCast_CUDA) { +TEST_F(NVFuserTest, FusionForceBf16NotAllCast_CUDA) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); - fusion->addInput(tv0); - fusion->addInput(tv1); + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); - // Group 1 - auto tv3 = sum(tv0, {1}); - auto tv4 = broadcast(tv3, {false, true, false}); - auto tv5 = sum(tv0, {1}); + 
fusion->addInput(tv0); + fusion->addInput(tv1); - // Group 2 - auto tv6 = add(tv4, tv1); // edge tv4, expect cast - auto tv7 = castOp(DataType::BFloat16, tv6); + // Group 1 + auto tv3 = sum(tv0, {1}); + auto tv4 = broadcast(tv3, {false, true, false}); + auto tv5 = sum(tv0, {1}); - // Group 3 - auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast + // Group 2 + auto tv6 = add(tv4, tv1); // edge tv4, expect cast + auto tv7 = castOp(DataType::BFloat16, tv6); - fusion->addOutput(tv7); - fusion->addOutput(tv8); + // Group 3 + auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast - FusionExecutorCache fec(std::move(fusion_ptr)); + fusion->addOutput(tv7); + fusion->addOutput(tv8); - std::vector shape{16, 16, 16}; + FusionExecutorCache fec(std::move(fusion_ptr)); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - auto complete_fusion = segmented_fusion->completeFusion(); - - // Check that the edge that wasn't fp16 is the producer of the - // reduction op, i.e. tv8 = sum(tv5,{1});. - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - if (edge_tv->getDataType() == DataType::Float) { - auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); - TORCH_CHECK(consumer->isA()); - } + std::vector shape{16, 16, 16}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); + + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + auto complete_fusion = segmented_fusion->completeFusion(); + + // Check that the edge that wasn't fp16 is the producer of the + // reduction op, i.e. tv8 = sum(tv5,{1});. 
+ for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + if (edge_tv->getDataType() == DataType::Float) { + auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); + TORCH_CHECK(consumer->isA()); } - } else { - GTEST_SKIP(); } #else - GTEST_SKIP(); + GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; #endif } -TEST(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16961,10 +17007,10 @@ TEST(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); + auto tv2 = mul(tv0, IrBuilder::create(2)); auto tv3 = broadcast(tv2, {false, false, true}); auto tv4 = add(tv3, tv1); - auto tv5 = mul(tv4, new Double(3)); + auto tv5 = mul(tv4, IrBuilder::create(3)); fusion->addOutput(tv5); // t4 cannot inner re-use t2, because there's a broadcast @@ -16978,13 +17024,13 @@ TEST(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto outputs = fe.runFusion({in0, in1}); testValidate(fusion, outputs, {in0, in1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseStressTest_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16995,17 +17041,17 @@ TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); - auto tv3 = mul(tv0, new Double(3)); + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = mul(tv0, IrBuilder::create(3)); auto tv4 = mul(tv2, tv3); // Broadcast buffer can be reused through outer sharing auto tv5 = broadcast(tv4, {true, false, false}); - auto tv6 = mul(tv5, new Double(5)); + auto tv6 = mul(tv5, IrBuilder::create(5)); auto tv7 = mul(tv6, tv1); - auto tv8 = mul(tv7, new Double(7)); + auto tv8 = mul(tv7, IrBuilder::create(7)); // tv9 shouldn't alias to avoid buffer over-subscription auto tv9 = broadcast(tv4, {true, false, false}); - auto tv10 = mul(tv9, new Double(9)); + auto tv10 = mul(tv9, IrBuilder::create(9)); auto tv11 = add(tv5, tv9); fusion->addOutput(tv7); fusion->addOutput(tv11); @@ -17031,7 +17077,7 @@ TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { auto t10 = t9 * 9; auto t11 = t5 + t9; FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; auto outputs = fe.runFusion({in0, in1}); @@ -17039,7 +17085,7 @@ TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { testValidate(fusion, outputs, {in0, in1}, {t7, t11}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17048,12 +17094,12 @@ TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { fusion->addInput(tv0); - auto tv1 = mul(tv0, new Double(2)); - auto tv2 = mul(tv1, new Double(2)); - auto tv3 = mul(tv2, new Double(2)); - auto tv4 = mul(tv3, new Double(2)); - auto tv5 = mul(tv4, new Double(2)); - auto tv6 = mul(tv5, new Double(2)); + auto tv1 = mul(tv0, IrBuilder::create(2)); + auto tv2 
= mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); + auto tv4 = mul(tv3, IrBuilder::create(2)); + auto tv5 = mul(tv4, IrBuilder::create(2)); + auto tv6 = mul(tv5, IrBuilder::create(2)); fusion->addOutput(tv6); @@ -17064,7 +17110,7 @@ TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { auto in0 = at::randn({256, 512}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0}); auto outputs = fe.runFusion({in0}); auto at_out = in0.mul(2).mul(2).mul(2).mul(2).mul(2).mul(2); @@ -17072,7 +17118,7 @@ TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17083,12 +17129,12 @@ TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); + auto tv2 = mul(tv0, IrBuilder::create(2)); auto tv3 = broadcast(tv2, {false, false, true}); auto tv4 = add(tv3, tv1); // T4 to be inner aliased first, and // shouldn't outer alias on top - auto tv5 = mul(tv4, new Double(3)); - auto tv6 = mul(tv5, new Double(3)); + auto tv5 = mul(tv4, IrBuilder::create(3)); + auto tv6 = mul(tv5, IrBuilder::create(3)); fusion->addOutput(tv6); tv0->computeAt(tv6, 1, ComputeAtMode::BestEffort); @@ -17098,7 +17144,7 @@ TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { auto in0 = at::randn({2, 2}, options); auto in1 = at::randn({2, 2, 2}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto outputs = fe.runFusion({in0, in1}); auto at_out = (in0.mul(2.0).unsqueeze(2) + in1).mul(3.0).mul(3.0); @@ -17106,7 +17152,7 @@ TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { testValidate(fusion, outputs, {in0, in1}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17116,8 +17162,8 @@ TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { fusion->addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = mul(tv1, new Double(2)); - auto tv3 = mul(tv2, new Double(2)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); fusion->addOutput(tv3); @@ -17134,7 +17180,7 @@ TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { auto in0 = at::randn({3, 3, 3}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0}); auto outputs = fe.runFusion({in0}); auto at_out = in0.sum(1).mul(2).mul(2); @@ -17142,7 +17188,7 @@ TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17151,9 +17197,9 @@ TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { fusion->addInput(tv0); - auto tv1 = mul(tv0, new Double(3)); - auto tv2 = mul(tv1, new Double(2)); - auto tv3 = mul(tv2, new Double(2)); + auto tv1 = mul(tv0, IrBuilder::create(3)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); // tv1 
used till here, cannot be reused by tv2 or tv3 auto tv4 = mul(tv3, tv1); @@ -17165,7 +17211,7 @@ TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { auto in0 = at::randn({16, 16}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0}); auto cg_outputs = fe.runFusion({in0}); auto at_t0 = in0 * 3.0; @@ -17174,7 +17220,7 @@ TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { testValidate(fusion, cg_outputs, {in0}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17185,12 +17231,12 @@ TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); - auto tv3 = mul(tv0, new Double(3)); + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = mul(tv0, IrBuilder::create(3)); auto tv4 = mul(tv2, tv3); auto tv5 = broadcast(tv4, {false, false, true}); auto tv6 = mul(tv5, tv1); - auto tv7 = mul(tv6, new Double(7)); + auto tv7 = mul(tv6, IrBuilder::create(7)); fusion->addOutput(tv7); // tv6 shouldn't re-use t2 or t3 because of @@ -17202,7 +17248,7 @@ TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { auto in0 = at::randn({2, 2}, options); auto in1 = at::randn({2, 2, 2}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto outputs = fe.runFusion({in0, in1}); auto t2 = in0 * 2; @@ -17214,7 +17260,7 @@ TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { testValidate(fusion, outputs, {in0, in1}, {t7}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue970_CUDA) { +TEST_F(NVFuserTest, FusionIssue970_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17230,14 +17276,13 @@ TEST(NVFuserTest, FusionIssue970_CUDA) { tv1->split(1, 4); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); at::manual_seed(0); at::Tensor t0 = at::randn({nelm, nelm}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); auto ref = sum(t0, {1}).unsqueeze(-1).expand({nelm, nelm}) + t0; @@ -17246,15 +17291,15 @@ TEST(NVFuserTest, FusionIssue970_CUDA) { } // Reproducer of #1016 -TEST(NVFuserTest, FusionIssue1016_CUDA) { +TEST_F(NVFuserTest, FusionIssue1016_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); @@ -17262,15 +17307,15 @@ TEST(NVFuserTest, FusionIssue1016_CUDA) { tv2->split(-1, 8); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 10; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0 + 1 + 2; @@ -17279,13 +17324,13 @@ TEST(NVFuserTest, FusionIssue1016_CUDA) { } // Reproducer of #1021 -TEST(NVFuserTest, FusionIssue1021_CUDA) { +TEST_F(NVFuserTest, FusionIssue1021_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = 
makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create<Double>(1)); auto tv2 = broadcast(tv1, {false, true}); fusion.addOutput(tv2); @@ -17298,12 +17343,12 @@ TEST(NVFuserTest, FusionIssue1021_CUDA) { tv2->axis(0)->parallelize(ParallelType::TIDx); tv2->axis(1)->parallelize(ParallelType::Vectorize); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10}, options); std::vector<IValue> inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = (t0 + 1).unsqueeze(-1); @@ -17312,7 +17357,7 @@ TEST(NVFuserTest, FusionIssue1021_CUDA) { } // Reproducer of issue #1053 -TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { +TEST_F(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { auto fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); @@ -17321,7 +17366,7 @@ TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { auto tv1 = sum(tv0, {0}); fusion->addOutput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create<Double>(1)); fusion->addOutput(tv2); tv1->split(0, 8); @@ -17340,20 +17385,20 @@ TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { auto at_tv2 = input1 + 1; FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_tv1, at_tv2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap1_CUDA) { auto fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(1); fusion->addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create<Double>(1)); + auto tv2 = add(tv0, IrBuilder::create<Double>(1)); fusion->addOutput(tv1); fusion->addOutput(tv2); @@ -17366,25 +17411,22 @@ TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) { // actual values are not statically known GpuLower gpulw(fusion.get()); const auto& pdmap = gpulw.parallelDimensionMap(); - auto kir_tv1 = gpulw.lowerValue(tv1)->as<kir::TensorView>(); - auto kir_tv2 = gpulw.lowerValue(tv2)->as<kir::TensorView>(); - for (const auto i : c10::irange(kir_tv1->domain()->domain().size())) { - auto dom1 = kir_tv1->domain()->domain()[i]; - auto dom2 = kir_tv2->domain()->domain()[i]; + for (const auto i : c10::irange(tv1->domain()->domain().size())) { + auto dom1 = tv1->domain()->domain()[i]; + auto dom2 = tv2->domain()->domain()[i]; TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent())); } TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() && - pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA<NamedScalar>() && + pdmap.get(ParallelType::TIDx)->as<NamedScalar>()->name() == "blockDim.x"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({32}, options); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( @@ -17396,7 +17438,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap2_CUDA) { auto fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); @@ -17418,16 +17460,15 @@ TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) 
{ const auto& pdmap = gpulw.parallelDimensionMap(); TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({11}, options); at::Tensor input2 = at::randn({11, 13}, options); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1, input2}); auto outputs = fe.runFusion({input1, input2}); auto ref = input1.unsqueeze(-1) + input2; @@ -17437,24 +17478,24 @@ TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) { } // Mix symbolic and concrete tensors -TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap3_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(1); fusion->addInput(tv0); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv2); - auto tv3 = add(tv0, new Double(1)); + auto tv3 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv3); tv2->split(0, 10); tv3->split(0, 20); - auto tv4 = add(tv0, new Double(1)); + auto tv4 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv4); - auto tv5 = add(tv0, new Double(1)); + auto tv5 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv5); // Not mapped but equal extent @@ -17471,19 +17512,18 @@ TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) { const auto& pdmap = gpulw.parallelDimensionMap(); TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); TORCH_CHECK( pdmap.get(ParallelType::TIDy)->isConst() && - pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); + pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({13}, options); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( @@ -17496,7 +17536,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) { } // Parallelizing merged broadcast domains -TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17504,7 +17544,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { fusion.addInput(tv0); auto tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -17526,16 +17566,15 @@ TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { const auto& pdmap = gpulw.parallelDimensionMap(); TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = 
at::randn({13}, options); at::Tensor input2 = at::randn({15, 13}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); auto ref = (input1 + 1).unsqueeze(0) + input2; @@ -17543,7 +17582,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17570,18 +17609,17 @@ TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) { TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); TORCH_CHECK( pdmap.get(ParallelType::TIDx)->isConst() && - pdmap.get(ParallelType::TIDx)->as()->value().value() == 4); + pdmap.get(ParallelType::TIDx)->as()->value().value() == 4); TORCH_CHECK( - pdmap.get(ParallelType::TIDy)->isA() && - pdmap.get(ParallelType::TIDy)->as()->name() == - "blockDim.y"); + pdmap.get(ParallelType::TIDy)->isA() && + pdmap.get(ParallelType::TIDy)->as()->name() == "blockDim.y"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({13}, options); at::Tensor input2 = at::randn({13, 15}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); auto ref = (input1).unsqueeze(-1) + input2; @@ -17589,7 +17627,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) { testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { +TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -17603,7 +17641,7 @@ TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { auto t13 = makeSymbolicTensor(3, DataType::Half); auto t15 = makeSymbolicTensor(3, DataType::Half); auto t17 = makeSymbolicTensor(3, DataType::Half); - auto d56 = new Double(); + auto d56 = IrBuilder::create(); fusion.addInput(t0); fusion.addInput(t1); @@ -17636,9 +17674,10 @@ TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { auto t29 = mul(t25, t23); auto t30 = sum(t29, {2}); auto t31 = broadcast(t30, {false, false, true}); - auto d59 = mul(t1->getRootDomain()[2]->extent(), new Double(1)); + auto d59 = + mul(t1->getRootDomain()[2]->extent(), IrBuilder::create(1)); auto t26 = mul(d59, t25); - auto txx = mul(t26, new Double(1)); + auto txx = mul(t26, IrBuilder::create(1)); auto t33 = sub(txx, t28); auto d70 = unaryOp(UnaryOpType::Reciprocal, d59); auto t35 = mul(d70, t6); @@ -17694,23 +17733,23 @@ TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { } } -TEST(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { +TEST_F(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); - auto tv3 = add(tv0, new Double(1)); - auto tv4 = add(tv3, new Double(1)); + auto tv3 = add(tv0, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); fusion.addOutput(tv4); - auto tv5 = add(tv0, new Double(1)); - auto tv6 = add(tv5, new Double(1)); + auto tv5 = add(tv0, 
IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); fusion.addOutput(tv6); // Case 1: local memory tensor computed serially and used by @@ -17732,13 +17771,13 @@ TEST(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { tv5->axis(-1)->parallelize(ParallelType::TIDx); tv5->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int nx = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({nx}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = t0 + 2; @@ -17748,16 +17787,16 @@ TEST(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { } // Repro of issue #1105 -TEST(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { +TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -17783,12 +17822,12 @@ TEST(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { } } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 3; @@ -17796,24 +17835,24 @@ TEST(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1099_CUDA) { +TEST_F(NVFuserTest, FusionIssue1099_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); auto tv3 = makeSymbolicTensor(1); fusion.addInput(tv3); // Just to make TIDx/y/z non-exact - auto tv4 = add(tv3, new Double(1)); - auto tv5 = add(tv4, new Double(1)); - auto tv6 = add(tv5, new Double(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); fusion.addOutput(tv6); tv2->split(0, 4); @@ -17835,13 +17874,13 @@ TEST(NVFuserTest, FusionIssue1099_CUDA) { tv6->split(0, 7); tv6->axis(-1)->parallelize(ParallelType::TIDz); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t3 = at::randn({19}, options); std::vector aten_inputs = {t0, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref_t2 = t0 + 2; @@ -17852,18 +17891,15 @@ TEST(NVFuserTest, FusionIssue1099_CUDA) { } // Repro of issue #1080 -TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionUnswitchPredicate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, 
IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv2->split(0, 4); @@ -17883,14 +17919,14 @@ TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int nx = 4; const int ny = 10; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({nx, ny}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = t0 + 2; @@ -17898,7 +17934,7 @@ TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1189_CUDA) { +TEST_F(NVFuserTest, FusionIssue1189_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17926,12 +17962,12 @@ TEST(NVFuserTest, FusionIssue1189_CUDA) { parallelize(tv2); parallelize(tv3); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({16, 16, 1}, options); at::Tensor t1 = at::randn({16, 16, 1}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); auto outputs = fe.runFusion({t0, t1}); auto ref = (t0 + t1).sum({1}); @@ -17939,7 +17975,7 @@ TEST(NVFuserTest, FusionIssue1189_CUDA) { testValidate(&fusion, outputs, {t0, t1}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1052_CUDA) { +TEST_F(NVFuserTest, FusionIssue1052_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17948,10 +17984,10 @@ TEST(NVFuserTest, FusionIssue1052_CUDA) { auto tv1 = makeSymbolicTensor(1); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv2); - auto tv3 = add(tv1, new Double(1)); + auto tv3 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv3); tv2->axis(-1)->parallelize(ParallelType::TIDx); @@ -17960,13 +17996,13 @@ TEST(NVFuserTest, FusionIssue1052_CUDA) { scheduler_utils::parallelizeAllLike(tv2, {tv0}); scheduler_utils::parallelizeAllLike(tv3, {tv1}); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10}, options); at::Tensor t1 = at::randn({100}, options); std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref_t2 = t0 + 1; @@ -17977,7 +18013,7 @@ TEST(NVFuserTest, FusionIssue1052_CUDA) { } // Repro of issue #1115 -TEST(NVFuserTest, FusionPointwiseBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18002,7 +18038,7 @@ TEST(NVFuserTest, FusionPointwiseBroadcast_CUDA) { schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto at_x_add_bias = at_x + at_bias; @@ -18012,23 +18048,23 @@ TEST(NVFuserTest, FusionPointwiseBroadcast_CUDA) { testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSmemAliasSerial_CUDA) { +TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); 
+ auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); // Just set the dimension of TIDx auto tv4 = makeSymbolicTensor(1); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); tv1->setMemoryType(MemoryType::Shared); @@ -18040,14 +18076,13 @@ TEST(NVFuserTest, FusionSmemAliasSerial_CUDA) { // TIDx. They should be predicated as they are redundant and can // interfere with smem aliasing (issue #1100). - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10}, options); - at::Tensor t4 = at::randn({1024}, options); std::vector aten_inputs = {t0, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 3; @@ -18056,14 +18091,14 @@ TEST(NVFuserTest, FusionSmemAliasSerial_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { +TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); auto tv2 = makeSymbolicTensor(1); @@ -18074,13 +18109,13 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { tv1->axis(0)->parallelize(ParallelType::TIDx); tv3->axis(0)->parallelize(ParallelType::BIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t2 = at::randn({19}, options); std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 1; @@ -18089,14 +18124,14 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); auto tv2 = makeSymbolicTensor(1); @@ -18107,13 +18142,13 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { tv1->axis(0)->parallelize(ParallelType::TIDx); tv3->axis(0)->parallelize(ParallelType::BIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t2 = at::randn({19}, options); std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 1; @@ -18122,7 +18157,7 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { +TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ 
-18134,12 +18169,12 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { auto tv2 = makeSymbolicTensor(3); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); auto tv4 = makeSymbolicTensor(3); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); tv1->axis(0)->parallelize(ParallelType::BIDx); @@ -18175,7 +18210,7 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { #endif } -TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18187,12 +18222,12 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { auto tv2 = makeSymbolicTensor(3); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); auto tv4 = makeSymbolicTensor(3); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); tvs.avg->axis(0)->parallelize(ParallelType::BIDx); @@ -18229,7 +18264,7 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { } // Repro of issue #1102 -TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { +TEST_F(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18237,18 +18272,18 @@ TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { fusion.addInput(tv0); // Just to make TIDx/y/z non-exact - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); auto tv4 = makeSymbolicTensor(1); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); - auto tv6 = add(tv5, new Double(1)); - auto tv7 = add(tv6, new Double(1)); - auto tv8 = add(tv7, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); + auto tv7 = add(tv6, IrBuilder::create(1)); + auto tv8 = add(tv7, IrBuilder::create(1)); auto tv9 = sum(tv8, {0}); fusion.addOutput(tv9); @@ -18274,13 +18309,13 @@ TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { tv5->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t4 = at::randn({19}, options); std::vector aten_inputs = {t0, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 3; @@ -18290,8 +18325,9 @@ TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { } // Repro of #1102 and #1129 -TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } Fusion fusion; @@ -18302,16 +18338,16 @@ TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { auto tv1 = makeSymbolicTensor(1); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); - auto tv3 = add(tv2, new Double(1)); - auto tv4 = add(tv3, new Double(1)); - auto tv5 = 
add(tv4, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); // Just to make TIDx/y/z non-exact - auto tvx = add(tv1, new Double(1)); - auto tvy = add(tvx, new Double(1)); - auto tvz = add(tvy, new Double(1)); + auto tvx = add(tv1, IrBuilder::create(1)); + auto tvy = add(tvx, IrBuilder::create(1)); + auto tvz = add(tvy, IrBuilder::create(1)); fusion.addOutput(tvz); tv5->split(0, 4); @@ -18335,13 +18371,13 @@ TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { tv->setMemoryType(MemoryType::Shared); } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t1 = at::randn({19}, options); std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 4; @@ -18351,21 +18387,24 @@ TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { } // Repro of issue #1136 -TEST(NVFuserTest, FusionFloatPow_CUDA) { +TEST_F(NVFuserTest, FusionFloatPow_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = binaryOp(BinaryOpType::Pow, tv0, new Int(4)); + auto tv1 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(4)); // To check if pow(tv0, 2) is replaced with tv0 * tv0 - auto tv2 = binaryOp(BinaryOpType::Pow, tv0, new Int(2)); + auto tv2 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); // To check if pow(tv0, 2.0) is replaced with tv0 * tv0 - auto tv3 = binaryOp(BinaryOpType::Pow, tv0, new Double(2)); - auto tv4 = binaryOp(BinaryOpType::Pow, tv0, new Int(3)); - auto tv5 = binaryOp(BinaryOpType::Pow, tv0, new Double(3)); - auto s = binaryOp(BinaryOpType::Pow, new Double(3), new Double(3)); + auto tv3 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); + auto tv4 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); + auto tv5 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); + auto s = binaryOp( + BinaryOpType::Pow, + IrBuilder::create(3), + IrBuilder::create(3)); auto tv6 = add(tv0, s); fusion.addOutput(tv1); @@ -18382,14 +18421,14 @@ TEST(NVFuserTest, FusionFloatPow_CUDA) { TransformPropagator::from(tv1); scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); // Negative inputs cause nan in Fuesr as use_fast_math is enabled t0 = abs(t0); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto p4 = at::pow(t0, 4); @@ -18406,7 +18445,7 @@ TEST(NVFuserTest, FusionFloatPow_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionIssue1127_CUDA) { +TEST_F(NVFuserTest, FusionIssue1127_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18435,7 +18474,7 @@ TEST(NVFuserTest, FusionIssue1127_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -TEST(NVFuserTest, FusionChannelsLastParser_CUDA) { +TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // This test may not pass if using a custom block sync as there may // be additional calls. Skip the test as it's not specifically // relevant with block synchronizatin. 
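The hunks above all apply the same migration: tests move from TEST to the TEST_F(NVFuserTest, ...) fixture, scalar IR nodes are built through the IR builder instead of raw new Double(...) / new Int(...), and FusionExecutor::compileFusion now receives the runtime inputs, so the executor is constructed after the ATen tensors. A minimal sketch of a post-migration test body, assuming the templated IrBuilder::create<Double>(...) spelling and c10::IValue for the input vector; the test name is illustrative only:

TEST_F(NVFuserTest, FusionAddOneSketch_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  // Scalars come from the IR builder rather than operator new.
  auto tv1 = add(tv0, IrBuilder::create<Double>(1.0));
  fusion.addOutput(tv1);

  tv1->axis(0)->parallelize(ParallelType::TIDx);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({32}, options);
  std::vector<c10::IValue> aten_inputs = {t0};

  // Inputs are now provided at compile time as well as at run time.
  FusionExecutor fe;
  fe.compileFusion(&fusion, aten_inputs);
  auto outputs = fe.runFusion(aten_inputs);

  testValidate(&fusion, outputs, aten_inputs, {t0 + 1}, __LINE__, __FILE__);
}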
@@ -18486,30 +18525,30 @@ TEST(NVFuserTest, FusionChannelsLastParser_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { if ((((((((((nvfuser_index_t)blockIdx.x) * 1) + 0) * 1) + 0) * 128) + ((nvfuser_index_t)threadIdx.x)) < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { - constexpr nvfuser_index_t ki674 = 0; + constexpr nvfuser_index_t i120 = 0; __half T9[1]; - constexpr nvfuser_index_t ki716 = 0; - T9[ki716] = 0; - constexpr nvfuser_index_t ki707 = 0; - T9[ki707] - = T2[((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * (((1 * T0.size[2]) * T0.size[1]) * T0.size[3])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * ((1 * T0.size[2]) * T0.size[1])) + (((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * (1 * T0.size[2])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3]) * 1)]; + constexpr nvfuser_index_t i132 = 0; + T9[i132] = 0; + constexpr nvfuser_index_t i128 = 0; + T9[i128] + = T2[((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * (((1 * T0.size[2]) * T0.size[1]) * T0.size[3])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * ((1 * T0.size[2]) * T0.size[1])) + (((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * (1 * T0.size[2])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3]) * 1)]; __half T8[1]; - constexpr nvfuser_index_t ki722 = 0; - T8[ki722] = 0; - constexpr nvfuser_index_t ki702 = 0; - T8[ki702] - = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki702) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; + constexpr nvfuser_index_t i134 = 0; + T8[i134] = 0; + constexpr nvfuser_index_t i126 = 0; + T8[i126] + = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i126) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; __half T10[1]; - constexpr nvfuser_index_t ki683 = 0; + constexpr nvfuser_index_t i124 = 0; float T3[1]; T3[0] - = __half2float(T9[ki683]); + = __half2float(T9[i124]); float T4[1]; T4[0] = T3[0]; float T1[1]; T1[0] - = __half2float(T8[ki683]); + = __half2float(T8[i124]); float T5[1]; T5[0] = T1[0] @@ -18517,11 +18556,11 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, float T6[1]; T6[0] = relu(T5[0]); - T10[ki683] + T10[i124] = __float2half(T6[0]); - constexpr nvfuser_index_t ki676 = 0; - T7[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki676) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] - = T10[ki676]; + constexpr nvfuser_index_t i122 = 0; + 
T7[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i122) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] + = T10[i122]; } } )"; @@ -18558,7 +18597,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, // TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { +TEST_F(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18566,8 +18605,8 @@ TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -18575,12 +18614,12 @@ TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { tv2->computeAt(tv3, -1); tv3->axis(0)->parallelize(ParallelType::Unswitch); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10, 1024}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = sum(t0, {1}) + 2; @@ -18588,24 +18627,24 @@ TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonContigOutputs_CUDA) { +TEST_F(NVFuserTest, FusionNonContigOutputs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); tv1->setContiguity(false); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor at_input = at::randn({10}, options); at::Tensor at_output = at::empty_strided({10}, {2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_input}); auto returned_outputs = fe.runFusion({at_input}, {at_output}); // Returned outputs should only contain one tensor that is the same @@ -18619,7 +18658,7 @@ TEST(NVFuserTest, FusionNonContigOutputs_CUDA) { testValidate(&fusion, {at_output}, {at_input}, {at_ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTestWarpSoftMax_CUDA) { +TEST_F(NVFuserTest, FusionTestWarpSoftMax_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18654,14 +18693,15 @@ TEST(NVFuserTest, FusionTestWarpSoftMax_CUDA) { // Test result FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref_output = at::_softmax(aten_input, 1, false); testValidate(&fusion, outputs, aten_inputs, {ref_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1133_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionIssue1133_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } Fusion fusion; @@ -18670,9 +18710,9 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -18702,20 +18742,20 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { // There should be no allocation other than those for tv1 and 
tv2 TORCH_CHECK(false, "Invalid allocation detected"); } - TORCH_CHECK(size->isA(), "Invalid allocation size"); - TORCH_CHECK(size->as()->isConst(), "Allocation not constant"); - auto size_int = size->as()->value().value(); + TORCH_CHECK(size->isA(), "Invalid allocation size"); + TORCH_CHECK(size->as()->isConst(), "Allocation not constant"); + auto size_int = size->as()->value().value(); if (alloc->buffer()->name() == 1) { TORCH_CHECK( size_int == split_factor, "Invalid allocation size: ", - size->as()->value().value()); + size->as()->value().value()); tv1_validated = true; } else { TORCH_CHECK( size_int == 1, "Invalid allocation size: ", - size->as()->value().value()); + size->as()->value().value()); tv2_validated = true; } } @@ -18724,12 +18764,12 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { TORCH_CHECK(tv1_validated, "Failed to validate tv1 allocation"); TORCH_CHECK(tv2_validated, "Failed to validate tv2 allocation"); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({99, 101}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = (t0 + 1).sum({1}) + 1; @@ -18737,7 +18777,7 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionRfactorContigIDs_CUDA) { +TEST_F(NVFuserTest, FusionRfactorContigIDs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18756,12 +18796,12 @@ TEST(NVFuserTest, FusionRfactorContigIDs_CUDA) { tv2->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({99, 101}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = t0.sum({1}); @@ -18769,7 +18809,7 @@ TEST(NVFuserTest, FusionRfactorContigIDs_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18831,7 +18871,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { aten_t0.size(1) * dataTypeSize(DataType::Float)); } -TEST(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18894,7 +18934,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { aten_t0.size(1) * dataTypeSize(DataType::Half)); } -TEST(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18979,7 +19019,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { (dataTypeSize(DataType::Half) + dataTypeSize(DataType::Float))); } -TEST(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19056,10 +19096,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { aten_t0.size(1) * dataTypeSize(DataType::Half)); } -TEST(NVFuserTest, PersistentBufferProjection_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, 
FusionPersistentBufferProjection_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -19107,8 +19144,9 @@ TEST(NVFuserTest, PersistentBufferProjection_CUDA) { testValidate(&fusion, cg_outputs, {aten_t0}, {aten_t7}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1223_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionIssue1223_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } Fusion fusion; @@ -19117,11 +19155,11 @@ TEST(NVFuserTest, FusionIssue1223_CUDA) { auto tv0 = makeContigTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0, 1}); fusion.addOutput(tv2); - auto tv3 = add(tv0, new Double(0)); + auto tv3 = add(tv0, IrBuilder::create(0)); fusion.addOutput(tv3); tv2->split(0, 4); @@ -19153,7 +19191,7 @@ TEST(NVFuserTest, FusionIssue1223_CUDA) { at::Tensor at_t0 = at::ones({11, 10}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {at_t0}); auto cg_outputs = fe.runFusion({at_t0}); auto at_t1 = (at_t0 + 1).sum(); @@ -19163,14 +19201,14 @@ TEST(NVFuserTest, FusionIssue1223_CUDA) { } // See #1247 and #1250 -TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { +TEST_F(NVFuserTest, FusionRfactorPredication1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = min(tv1, {0}); fusion.addOutput(tv2); @@ -19179,7 +19217,7 @@ TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { auto tv3 = makeContigTensor(1); fusion.addInput(tv3); - auto tv4 = add(tv3, new Double(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); fusion.addOutput(tv4); tv2->split(0, 4); @@ -19197,7 +19235,7 @@ TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { at::Tensor at_t3 = at::randn({128}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {at_t0, at_t3}); auto cg_outputs = fe.runFusion({at_t0, at_t3}); auto at_t2 = (at_t0 + 1).min(); @@ -19207,7 +19245,7 @@ TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { +TEST_F(NVFuserTest, FusionRfactorPredication2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19221,7 +19259,7 @@ TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { auto tv2 = makeContigTensor(1); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); tv1->split(0, 4); @@ -19250,7 +19288,7 @@ TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { at::Tensor at_t3 = at::randn({128}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {at_t0, at_t3}); auto cg_outputs = fe.runFusion({at_t0, at_t3}); auto at_t2 = std::get<0>(at_t0.min(0)); @@ -19260,7 +19298,7 @@ TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19292,7 +19330,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, "Only tv1 should have a non-divisible predicate."); - 
for (auto tv : {tv1}) { + for (auto tv : {loweredTv(tv1, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19309,7 +19347,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = t0.sum(); @@ -19318,14 +19356,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { } // Repro of issue #1074 -TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv2->split(0, 2); @@ -19346,7 +19384,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, "Only tv2 should have a non-divisible predicate."); - for (auto tv : {tv2}) { + for (auto tv : {loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19363,7 +19401,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { at::Tensor t0 = at::randn({13, 17}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = t0 + 2; @@ -19372,14 +19410,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { } // Similar to FusionNonDivisibleSplit1 but with unswitch -TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0}); fusion.addOutput(tv2); @@ -19397,7 +19435,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {tv1, tv2}) { + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19414,7 +19452,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19423,14 +19461,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { } // Non-divisible split through merge -TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0, 1}); fusion.addOutput(tv2); @@ -19447,7 +19485,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {tv1, tv2}) { + 
for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19464,7 +19502,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { at::Tensor t0 = at::randn({24, 2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19473,14 +19511,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { } // Nested splits -TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0}); fusion.addOutput(tv2); @@ -19501,7 +19539,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {tv1, tv2}) { + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19518,7 +19556,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19527,7 +19565,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { } // Vectorized non-divisible split. Must be validated at run time -TEST(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19556,13 +19594,12 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { splits_to_predicate); } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); - auto t0 = at::randn({32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = t0; @@ -19576,7 +19613,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { } // If a split is validated at run time, it's not necessary to predicate. 
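Before the next test, the validation pattern shared by the NonDivisibleSplit tests above is worth spelling out: GpuLower clones the fusion, so the lowered counterpart of a TensorView has to be looked up through loweredTv() before it can be searched for in splitsToPredicate(). A condensed sketch of that check, with tv1 and tv2 standing in for whichever tensors are expected to carry a non-divisible-split predicate:

GpuLower gpulw(&fusion);
const auto& splits_to_predicate =
    gpulw.nonDivisibleSplitInfo().splitsToPredicate();
// Look up the lowered clones, not the original fusion tensors.
for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) {
  auto it = splits_to_predicate.find(tv);
  TORCH_CHECK(
      it != splits_to_predicate.end(),
      "No non-divisible split info found for a tensor that should have one");
}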
-TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19584,7 +19621,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { fusion.addInput(tv0); auto tv1 = set(tv0); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = sum(tv2, {0}); fusion.addOutput(tv3); @@ -19611,13 +19648,13 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { splits_to_predicate); } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); auto t0 = at::randn({1024}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19625,6 +19662,784 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionIssue1284Repro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector input_shape_0 = {10, 20}; + std::vector input_shape_1 = {15}; + + TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); + TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); + fusion.addInput(in_0); + fusion.addInput(in_1); + + TensorView* out_0 = add(in_0, IrBuilder::create(0.f)); + TensorView* out_1 = add(in_1, IrBuilder::create(2.f)); + + fusion.addOutput(out_0); + fusion.addOutput(out_1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_in_0 = at::randn(input_shape_0, options); + at::Tensor at_in_1 = at::randn(input_shape_1, options); + std::vector aten_inputs = {at_in_0, at_in_1}; + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + auto t1 = at_in_1 + 2; + + auto runtime = fec.getMostRecentKernelRuntime(); + TORCH_INTERNAL_ASSERT(runtime->isSegmented()); + TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); + + testValidate( + &fusion, outputs, {at_in_0, at_in_1}, {at_in_0, t1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1284Repro2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector input_shape_0 = {4, 4}; + std::vector input_shape_1 = {3, 4, 4}; + std::vector input_shape_2 = {2, 8, 4, 4}; + + TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); + TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); + TensorView* in_2 = makeSymbolicTensor(input_shape_2.size()); + + fusion.addInput(in_0); + fusion.addInput(in_1); + fusion.addInput(in_2); + + TensorView* out_0 = add(in_0, in_1); + TensorView* out_1 = add(in_0, in_2); + + fusion.addOutput(out_0); + fusion.addOutput(out_1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_in_0 = at::randn(input_shape_0, options); + at::Tensor at_in_1 = at::randn(input_shape_1, options); + at::Tensor at_in_2 = at::randn(input_shape_2, options); + + std::vector aten_inputs = {at_in_0, at_in_1, at_in_2}; + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + auto t0 = at_in_0 + at_in_1; + auto t1 = at_in_0 + at_in_2; + + auto runtime = fec.getMostRecentKernelRuntime(); + TORCH_INTERNAL_ASSERT(runtime->isSegmented()); + 
TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); + + testValidate( + &fusion, + outputs, + {at_in_0, at_in_1, at_in_2}, + {t0, t1}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1305Repro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto t0 = makeContigTensor(1); + auto t1 = makeContigTensor(2); + + fusion.addInput(t0); + fusion.addInput(t1); + + auto t2 = broadcast(t0, {true, false}); + auto t3 = add(t1, t2); + auto t4 = add(t3, t2); + auto t5 = sum(t4, {1}); + auto t6 = broadcast(t5, {false, true}); + auto t7 = add(t3, t6); + + fusion.addOutput(t7); + + t3->computeAt(t7, -1, ComputeAtMode::MostInlined); + + TORCH_INTERNAL_ASSERT(t3->getComputeAtPosition() == 1); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 1); + + tv3->axis(-2)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, -1); + + tv3->axis(-2)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 1); + + // tv2 is invalid to double-buffer as its producer, tv1, is + // computed inside the double-buffering loop. 
+ ASSERT_ANY_THROW(tv2->doubleBuffer()); + + // Moving tv2 inner makes tv1 large enough to double-buffer tv2 + tv2->computeAt(tv3, 2); + + tv2->doubleBuffer(); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering smem to local and unswitch +TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + tv3->split(-1, 8); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 2); + tv2->computeAt(tv3, -1); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::Unswitch); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + tv2->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering gmem to shared and unswitch +TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + + tv2->split(-1, 128); + tv2->split(-1, 32); + tv2->split(-1, 8); + TransformPropagator::from(tv2); + + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, -1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Unswitch); + scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering smem to local and unroll +TEST_F(NVFuserTest, FusionDoubleBuffering6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 16); + tv3->split(-2, 4); + tv3->split(-2, 2); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 1); + tv2->computeAt(tv3, -1); + + tv3->axis(2)->parallelize(ParallelType::Unroll); + tv3->axis(4)->parallelize(ParallelType::TIDx); + + tv2->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + 
auto t0 = at::randn({199}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering and vectorize +TEST_F(NVFuserTest, FusionDoubleBuffering7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv2->split(-1, 128); + tv2->split(-1, 4); + TransformPropagator::from(tv2); + + tv1->computeAt(tv2, 2); + + tv2->axis(-2)->parallelize(ParallelType::TIDx); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({200}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Multiple tensors to double-buffer +TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(1); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv4->split(0, 32); + tv4->split(0, 4); + TransformPropagator::from(tv4); + + tv0->computeAt(tv4, 1); + tv1->computeAt(tv4, 1); + + tv4->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion)); + + tv2->doubleBuffer(); + tv3->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({100}, options); + auto t1 = at::randn({100}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Nested double buffering from gmem to smem and smem to register +TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto out = tv1; + fusion.addOutput(out); + + auto tv2 = tv0->cache_after(); + auto tv3 = tv2->cache_after(); + + out->split(0, 32); + out->split(0, 4); + TransformPropagator::from(out); + + tv2->setMemoryType(MemoryType::Shared); + + tv2->computeAt(out, 1); + tv3->computeAt(out, -1); + + out->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + tv2->doubleBuffer(); + tv3->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1001}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// FusionSmemBlockGemmCache + double buffering at both smem and local +TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(2); // (M, K) + TensorView* tv1 = makeSymbolicTensor(2); // (K, N) + TensorView* tv2 = broadcast(tv0, 
{false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + TensorView* tv6 = tv5->cache_before(); + + // For smem double buffering + auto tv0_cache_local = tv0->cache_after(); + auto tv1_cache_local = tv1->cache_after(); + + // For register double buffering + auto tv0_cache_smem = tv0->cache_after(); + auto tv1_cache_smem = tv1->cache_after(); + + const int BSX = 32; + const int TSX = 8; + + // [M, K, N] + tv6->split(-1, BSX); + tv6->split(-1, TSX); + tv6->split(1, BSX); + tv6->split(0, BSX); + tv6->split(1, TSX); + // [M/BSX, BSX/TSX, TSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->reorder( + {{4, 7}, {7, 6}, {6, 5}, {2, 4}, {1, 3}, {3, 2}, {5, 1}, {0, 0}}); + // [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + + auto tv6_rf = tv6->rFactor({-1}); + + TransformPropagator::from(tv6_rf); + + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + + tv6_rf->computeAt(tv6, -1); + tv0_cache_local->computeAt(tv6_rf, -1); + tv1_cache_local->computeAt(tv6_rf, -1); + + tv0_cache_smem->setMemoryType(MemoryType::Shared); + tv1_cache_smem->setMemoryType(MemoryType::Shared); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-3)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); + + tv0_cache_local->doubleBuffer(); + tv1_cache_local->doubleBuffer(); + + tv0_cache_smem->doubleBuffer(); + tv1_cache_smem->doubleBuffer(); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { + auto mem_types = {MemoryType::Shared, MemoryType::Local}; + + for (auto mem_type : mem_types) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(mem_type); + + tv3->split(-1, 4); + TransformPropagator::from(tv3); + + tv1->computeAt(tv3, -2); + + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({15}, options); + FusionExecutor fe; + fe.compileFusion(&fusion); + + // This should throw an exception as the extent of t0 is not + // divisible by the vector width + ASSERT_ANY_THROW(fe.runFusion({t0})); + + auto t1 = at::randn({16}, options); + auto cg_outputs = fe.runFusion({t1}); + + auto ref = t1; + + testValidate(&fusion, cg_outputs, {t1}, {ref}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({10, 1}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({10, 20}); + fusion.addInput(tv1); + auto tv2 = 
makeConcreteTensor({10, 10}); + fusion.addInput(tv2); + + // Not concretized + auto tv3 = sum(tv2, {1}); + auto tv4 = broadcast(tv3, {false, true}); + auto tv5 = add(tv0, tv4); + fusion.addOutput(tv5); + + // Concretized + auto tv6 = sum(tv2, {1}); + auto tv7 = broadcast(tv6, {false, true}); + auto tv8 = add(tv1, tv7); + fusion.addOutput(tv8); + + for (auto tv : {tv3, tv4, tv5, tv6, tv7, tv8}) { + tv->axis(1)->parallelize(ParallelType::TIDx); + } + + GpuLower gpulw(&fusion); + TORCH_CHECK(!gpulw.concretizedBroadcastDomains().isConcretized( + loweredTv(tv4, gpulw)->axis(1))); + TORCH_CHECK(gpulw.concretizedBroadcastDomains().isConcretized( + loweredTv(tv7, gpulw)->axis(1))); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({10, 1}, options); + auto t1 = at::randn({10, 20}, options); + auto t2 = at::randn({10, 10}, options); + std::vector aten_inputs = {t0, t1, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t5 = t0 + t2.sum({1}).unsqueeze(-1); + auto t8 = t1 + t2.sum({1}).unsqueeze(-1); + + testValidate(&fusion, outputs, aten_inputs, {t5, t8}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 1}); + auto tv2 = broadcast(tv1, {true}); + auto tv3 = broadcast(tv2, {false, true}); + fusion.addOutput(tv3); + + // tv1 is thread-predicated with TIDx and TIDy + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::TIDy); + // tv2 broadcasts along TIDx + tv2->axis(0)->parallelize(ParallelType::TIDx); + // tv3 broadcasts along TIDy + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDy); + + // Both tv2 and tv3 broadcast along predicated TID dimensions, but + // since the broadcast domains are not concretized, there should be + // no actual parallel broadcast + + GpuLower gpulw(&fusion); + TORCH_CHECK( + !gpulw.kernel()->summary().has_block_broadcasts && + !gpulw.kernel()->summary().has_grid_broadcasts, + "There must be no parallel broadcast in this fusion"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({10, 11}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t3 = t0.sum().unsqueeze(-1).unsqueeze(-1); + + testValidate(&fusion, outputs, aten_inputs, {t3}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector input_shape({10, 4, 8}); + std::vector output_shape({8, 4, 1}); + + auto tv0 = makeConcreteTensor(input_shape); + fusion.addInput(tv0); + + auto tv2 = sum(tv0, {0}); + auto tv3 = set(tv2); + auto tv4 = + view(tv3, {input_shape.begin() + 1, input_shape.end()}, output_shape); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + // The view op adds a broadcast domain in tv4, which is + // parallelized. However, it is never materialized, so there should + // be no parallel broadcast. 
+ + GpuLower gpulw(&fusion); + TORCH_CHECK( + !gpulw.kernel()->summary().has_block_broadcasts && + !gpulw.kernel()->summary().has_grid_broadcasts, + "There must be no parallel broadcast in this fusion"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(input_shape, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t5 = at::native::view(t0.sum(0), output_shape) + 1; + + testValidate(&fusion, outputs, aten_inputs, {t5}, __LINE__, __FILE__); +} + +// Merging non-broadcast and broadcast domains +// TODO: Fix use case see issue https://github.com/csarofeen/pytorch/issues/1418 +// validateParallelize does not pass. Even if it's skipped, +// generated code is invalid as blockBroadcast is not used. +#if 0 +TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + fusion.addOutput(tv3); + + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv2->merge(0, 1); + tv2->axis(0)->parallelize(ParallelType::TIDx); + // TODO: When set to shared memory, this kernel should be correct, but fails + // validation and when skipped produces incorrect code + tv2->setMemoryType(MemoryType::Shared); + + tv3->merge(0, 1); + tv3->axis(0)->parallelize(ParallelType::TIDx); + + fusion.printMath(); + fusion.printKernel(); +} +#endif + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/test/cpp/jit/test_gpu_shift.cpp b/test/cpp/jit/test_gpu_shift.cpp index 71fa156c2d2..2665f16563b 100644 --- a/test/cpp/jit/test_gpu_shift.cpp +++ b/test/cpp/jit/test_gpu_shift.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -18,8 +19,6 @@ #include #include #include -#include -#include #include #include #include @@ -82,33 +81,38 @@ void checkIntValue( void checkIntValue( kir::ExpressionEvaluator& evaluator, - const kir::Val* val, - kir::Int::ScalarType expected_value) { + const Val* val, + Int::ScalarType expected_value) { const auto actual_value = evaluator.evaluate(val); TORCH_CHECK(actual_value.has_value()); TORCH_CHECK(actual_value.value() == expected_value); } +// Used to signify invalid ranges, i.e., values at offset 0 to +// start_offset, and values at offset stop_offset to the end of the +// domain. 
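+// For example, in the ATen shift helper below, shift(t, {2}, {1})
+// rolls t by two along the first dimension, zero-pads the two
+// shifted-in positions, and then overwrites position 0 with this
+// marker, since only one position of padding is treated as valid.
+// With the default padding (pad == |offset|), nothing is marked.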
+static constexpr int invalid_marker = 1; + // ATen version of tensor shifting auto shift( at::Tensor tensor, const std::vector& offsets, - std::vector strides = {}) { + std::vector padding = {}) { TORCH_INTERNAL_ASSERT(tensor.ndimension() == offsets.size()); - if (strides.empty()) { - strides = std::vector(tensor.ndimension(), 1); + if (padding.empty()) { + padding = offsets; + for (auto& p : padding) { + p = std::abs(p); + } } at::Tensor t = tensor; - std::vector stride_indices; for (size_t i = 0; i < offsets.size(); ++i) { - auto stride = strides[i]; - stride_indices.push_back( - at::indexing::Slice(0, at::indexing::None, stride)); - const auto offset = offsets[i]; + auto offset = offsets[i]; + t = t.roll(offsets[i], i); if (offset == 0) { continue; } - t = t.roll(offsets[i], i); + // Zero padding std::vector indices( tensor.ndimension(), at::indexing::Slice(0, at::indexing::None)); if (offset > 0) { @@ -117,8 +121,20 @@ auto shift( indices[i] = at::indexing::Slice(offset, at::indexing::None); } t.index(indices) = 0; + // Fill the outside range by the special marker value. + const auto pad = padding[i]; + if (offset > 0) { + indices[i] = at::indexing::Slice(0, offset - pad); + } else { + offset += pad; + TORCH_INTERNAL_ASSERT(offset <= 0); + if (offset == 0) { + continue; + } + indices[i] = at::indexing::Slice(offset, at::indexing::None); + } + t.index(indices) = invalid_marker; } - t = t.index(stride_indices); return t; } @@ -153,13 +169,28 @@ auto gather( TORCH_CHECK(w_size != 0); const auto& pad = pad_width[i]; TORCH_CHECK(pad.size() == 2); + const auto out_extent_adj = -w_size + 1 + pad[0] + pad[1]; + TORCH_INTERNAL_ASSERT(out_extent_adj <= 0); + const auto stride = strides[i]; + TORCH_CHECK(stride >= 1); + at::Tensor concat_tensor; + for (int w = 0; w < w_size; ++w) { std::vector shift_offsets(t.ndimension(), 0); shift_offsets[i] = pad[0] - w; - std::vector shift_strides(t.ndimension(), 1); - shift_strides[i] = strides[i]; - auto shifted = shift(t, shift_offsets, shift_strides); + auto shifted = shift(t, shift_offsets); + // Apply stride + if (stride != 1) { + std::vector indices( + shifted.ndimension(), at::indexing::Slice(0, at::indexing::None)); + if (out_extent_adj == 0) { + indices[i] = at::indexing::Slice(0, at::indexing::None, strides[i]); + } else { + indices[i] = at::indexing::Slice(0, out_extent_adj, strides[i]); + } + shifted = shifted.index(indices); + } shifted = shifted.unsqueeze(-1); if (w == 0) { concat_tensor = shifted; @@ -169,13 +200,32 @@ auto gather( } t = concat_tensor; } + + // Fill invalid regions with the marker. Note that when non-unit + // stride is used, it trims invalid regions, so no marking is + // necessary. 
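+ // For example, with window_shape {3}, pad_width {{1, 0}}, and unit
+ // stride, out_extent_adj is -3 + 1 + 1 + 0 = -1, so the last
+ // position along that dimension is marked invalid; with pad_width
+ // {{1, 1}} the adjustment is 0 and nothing is marked.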
+ for (size_t i = 0; i < window_shape.size(); ++i) { + if (strides[i] != 1) { + continue; + } + + const auto out_extent_adj = + -window_shape[i] + 1 + pad_width[i][0] + pad_width[i][1]; + if (out_extent_adj < 0) { + std::vector indices( + t.ndimension(), at::indexing::Slice(0, at::indexing::None)); + indices[i] = at::indexing::Slice(out_extent_adj, at::indexing::None); + t.index(indices) = invalid_marker; + } + } + return t; } } // namespace // Shift an input tensor -TEST(NVFuserTest, FusionShift1_CUDA) { +TEST_F(NVFuserTest, FusionShift1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -202,7 +252,7 @@ TEST(NVFuserTest, FusionShift1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = shift(t0, {-1, 0}); @@ -219,19 +269,19 @@ TEST(NVFuserTest, FusionShift1_CUDA) { } // Shifts an intermediate tensor -TEST(NVFuserTest, FusionShift2_CUDA) { +TEST_F(NVFuserTest, FusionShift2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {-1, 0}); fusion.addOutput(tv2); // make it a little more complex - auto tv3 = add(tv0, new Double(3)); - auto tv4 = add(tv3, new Double(4)); + auto tv3 = add(tv0, IrBuilder::create(3)); + auto tv4 = add(tv3, IrBuilder::create(4)); auto tv5 = shift(tv4, {-1, 0}); auto tv6 = shift(tv4, {0, -1}); auto tv7 = shift(tv4, {1, 0}); @@ -250,21 +300,22 @@ TEST(NVFuserTest, FusionShift2_CUDA) { // t4 allocation: (t3.size[0] + 2) * (t3.size[1] + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 3 || tensor_name == 4) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { if (tensor_name == 1 && i == 1) { - TORCH_CHECK(alloc->shape().at(i)->isA()); + TORCH_CHECK(alloc->shape().at(i)->isA()); continue; } auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add); - TORCH_CHECK(def->as()->lhs()->isA()); - auto rhs = dynamic_cast(def->as()->rhs()); + dynamic_cast(alloc->shape().at(i)->definition()); + TORCH_CHECK( + def != nullptr && def->getBinaryOpType() == BinaryOpType::Add); + TORCH_CHECK(def->as()->lhs()->isA()); + auto rhs = dynamic_cast(def->as()->rhs()); TORCH_CHECK(rhs != nullptr && rhs->isConst()); int rhs_value = *rhs->value(); if (tensor_name == 1) { @@ -290,7 +341,7 @@ TEST(NVFuserTest, FusionShift2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -309,14 +360,14 @@ TEST(NVFuserTest, FusionShift2_CUDA) { testValidate(&fusion, outputs, inputs, {t2, t11}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { +TEST_F(NVFuserTest, FusionShiftRightOfCA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); fusion.addOutput(tv2); @@ -324,15 +375,15 @@ TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { tv1->setMemoryType(MemoryType::Global); - 
FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 100; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -341,16 +392,16 @@ TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { TORCH_CHECK(t2.allclose(outputs[0])); } -TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) { +TEST_F(NVFuserTest, FusionShiftLeftOfCA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = shift(tv2, {-1, 0}); - auto tv4 = add(tv3, new Double(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); fusion.addOutput(tv4); tv0->computeAt(tv4, -1); @@ -360,13 +411,13 @@ TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -TEST(NVFuserTest, FusionShiftSplit1_CUDA) { +TEST_F(NVFuserTest, FusionShiftSplit1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); auto tv3 = shift(tv1, {0, -2}); fusion.addOutput(tv2); @@ -379,35 +430,29 @@ TEST(NVFuserTest, FusionShiftSplit1_CUDA) { tv0->computeAt(tv2, -2); tv0->computeAt(tv3, -2); - // t1 allocation: (4 + 3) + // t1 allocation: 7 GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 3); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && size->value().value() == 7); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 9; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -417,23 +462,23 @@ TEST(NVFuserTest, FusionShiftSplit1_CUDA) { testValidate(&fusion, outputs, inputs, {t2, t3}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftSplit2_CUDA) { +TEST_F(NVFuserTest, FusionShiftSplit2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = shift(tv2, {0, -1}); auto tv4 = shift(tv2, {0, 1}); auto tv5 = add(tv3, tv4); fusion.addOutput(tv5); - auto tv6 = add(tv0, new Double(1)); + auto tv6 = add(tv0, IrBuilder::create(1)); auto tv7 = 
shift(tv6, {0, 0}); - auto tv8 = add(tv7, new Double(1)); + auto tv8 = add(tv7, IrBuilder::create(1)); fusion.addOutput(tv8); int split_factor = 4; @@ -444,26 +489,20 @@ TEST(NVFuserTest, FusionShiftSplit2_CUDA) { tv0->computeAt(tv5, -2); tv0->computeAt(tv8, -2); - // t1 and t2 allocation: (4 + 2) - // t4 allocation: (4) + // t1 and t2 allocation: 6 + // t4 allocation: 4 GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && size->value().value() == 6); } else if (tensor_name == 4) { TORCH_CHECK(alloc->shape().size() == 1); - auto size = dynamic_cast(alloc->shape().at(0)); + auto size = dynamic_cast(alloc->shape().at(0)); TORCH_CHECK(size != nullptr && size->isConst()); int size_value = *size->value(); TORCH_CHECK(size_value == split_factor); @@ -471,15 +510,15 @@ TEST(NVFuserTest, FusionShiftSplit2_CUDA) { } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 9; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 2; @@ -494,14 +533,14 @@ TEST(NVFuserTest, FusionShiftSplit2_CUDA) { testValidate(&fusion, outputs, inputs, {t5, t8}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { +TEST_F(NVFuserTest, FusionShiftDoubleSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {0, 1}); fusion.addOutput(tv3); @@ -518,35 +557,29 @@ TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { // t2: [i1, i2/8, 8] // t3: [i1, i2/8, 8] - // t1 and t2 allocation: (split_factor1 + 1) + // t1 and t2 allocation: (split_factor1 + 1) = 9 GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(0)); + 
TORCH_CHECK( + size != nullptr && size->isConst() && size->value().value() == 9); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 3; @@ -555,7 +588,7 @@ TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { +TEST_F(NVFuserTest, FusionShift3ptStencil_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -576,7 +609,7 @@ TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -598,32 +631,29 @@ TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { // cache allocation: (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == cache->name()) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor + 2); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); + cache->doubleBuffer(); int numel_x = 99; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; @@ -631,7 +661,7 @@ TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencil_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -651,7 +681,7 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -672,28 +702,22 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { // cache allocation: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == cache->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - 
int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor[i] + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); + cache->doubleBuffer(); int numel_x = 99; int numel_y = 101; @@ -701,6 +725,9 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -712,7 +739,7 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { +TEST_F(NVFuserTest, FusionShift9ptStencil_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -740,7 +767,7 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -763,28 +790,22 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { // cache allocation: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == cache->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor[i] + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); + cache->doubleBuffer(); int numel_x = 99; int numel_y = 101; @@ -792,6 +813,9 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -803,13 +827,13 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { +TEST_F(NVFuserTest, FusionShiftSmemBlocking_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); fusion.addOutput(tv2); @@ -826,35 +850,30 @@ TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { // tv1 allocation: (split_factor + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = 
dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == tv1->name()) { TORCH_CHECK(alloc->shape().size() == 1); for (int i = 0; i < 1; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == smem_block_factor && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == smem_block_factor + 1); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 100; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -864,7 +883,7 @@ TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { +TEST_F(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -881,7 +900,7 @@ TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -902,14 +921,16 @@ TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { tv_out->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); + tv0_cache->doubleBuffer(); int numel_x = 99; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; @@ -917,7 +938,7 @@ TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -937,7 +958,7 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -965,15 +986,15 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-2)->parallelize(ParallelType::TIDy); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -985,13 +1006,13 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } 
-TEST(NVFuserTest, FusionShiftMerge1_CUDA) { +TEST_F(NVFuserTest, FusionShiftMerge1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {-1, 1}); fusion.addOutput(tv2); @@ -1006,35 +1027,30 @@ TEST(NVFuserTest, FusionShiftMerge1_CUDA) { // t1 allocation: (split_factor + 1) * (split_factor + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor + 1); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1044,13 +1060,13 @@ TEST(NVFuserTest, FusionShiftMerge1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftMerge2_CUDA) { +TEST_F(NVFuserTest, FusionShiftMerge2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}); auto tv3 = shift(tv1, {-1, 1}); auto tv4 = add(tv2, tv3); @@ -1067,35 +1083,30 @@ TEST(NVFuserTest, FusionShiftMerge2_CUDA) { // t1 allocation: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + 
fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1106,14 +1117,14 @@ TEST(NVFuserTest, FusionShiftMerge2_CUDA) { TORCH_CHECK(t4.allclose(outputs[0])); } -TEST(NVFuserTest, FusionShiftGlobal_CUDA) { +TEST_F(NVFuserTest, FusionShiftGlobal_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); auto tv3 = shift(tv1, {-1, 0}); auto tv4 = add(tv2, tv3); @@ -1132,17 +1143,18 @@ TEST(NVFuserTest, FusionShiftGlobal_CUDA) { // t1 allocation: (t1.size[0] + 1) * (t1.size[1] + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add); - TORCH_CHECK(def->as()->lhs()->isA()); - auto rhs = dynamic_cast(def->as()->rhs()); + dynamic_cast(alloc->shape().at(i)->definition()); + TORCH_CHECK( + def != nullptr && def->getBinaryOpType() == BinaryOpType::Add); + TORCH_CHECK(def->as()->lhs()->isA()); + auto rhs = dynamic_cast(def->as()->rhs()); TORCH_CHECK(rhs != nullptr && rhs->isConst()); int rhs_value = *rhs->value(); TORCH_CHECK(rhs_value == 1); @@ -1159,7 +1171,7 @@ TEST(NVFuserTest, FusionShiftGlobal_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1171,14 +1183,14 @@ TEST(NVFuserTest, FusionShiftGlobal_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { +TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {0, 1}); fusion.addOutput(tv3); @@ -1194,33 +1206,27 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { // t1 and t2 allocation: (split_factor1 + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { - TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor1 + 1); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 3; @@ -1229,15 +1235,15 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { +TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {1, 1}); fusion.addOutput(tv3); @@ -1271,35 +1277,30 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { // t1 and t2 allocation: (split_factor1 + 1) * (split_factor1 + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor1 + 1); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = shift(t0 + 1 + 2, {1, 1}); @@ -1307,7 +1308,7 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1327,7 +1328,7 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -1361,35 +1362,30 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { // cache allocation: (split_factor1 + 2) * (split_factor2 + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == tv0_cache->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - 
TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor[i] + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -1401,7 +1397,7 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain1_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1416,15 +1412,15 @@ TEST(NVFuserTest, FusionShiftChain1_CUDA) { tv0->computeAt(tv2, -2); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = shift(shift(t0, {0, 1}), {0, 1}); @@ -1432,7 +1428,7 @@ TEST(NVFuserTest, FusionShiftChain1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain2_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1446,15 +1442,15 @@ TEST(NVFuserTest, FusionShiftChain2_CUDA) { tv0->computeAt(tv2, -2); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = shift(shift(t0, {0, 1}), {0, -1}); @@ -1462,13 +1458,13 @@ TEST(NVFuserTest, FusionShiftChain2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain3_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); auto tv3 = shift(tv2, {0, 1}); fusion.addOutput(tv3); @@ -1484,40 +1480,33 @@ TEST(NVFuserTest, FusionShiftChain3_CUDA) { // tv1: (split_factor + 2) // tv2: (split_factor + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 1); for (int i = 0; i < 1; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = 
dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK(size != nullptr && size->isConst()); if (tensor_name == 1) { - TORCH_CHECK(rhs_value == 2); + TORCH_CHECK(size->value().value() == split_factor + 2); } else if (tensor_name == 2) { - TORCH_CHECK(rhs_value == 1); + TORCH_CHECK(size->value().value() == split_factor + 1); } } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1528,7 +1517,7 @@ TEST(NVFuserTest, FusionShiftChain3_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain4_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1558,42 +1547,36 @@ TEST(NVFuserTest, FusionShiftChain4_CUDA) { // tv2: (split_factor + 7) * (split_factor + 7) // tv3: (split_factor + 4) * (split_factor + 4) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK(size != nullptr && size->isConst()); + auto size_val = size->value().value(); if (tensor_name == 1) { - TORCH_CHECK(rhs_value == 9); + TORCH_CHECK(size_val == split_factor + 9); } else if (tensor_name == 2) { - TORCH_CHECK(rhs_value == 7); + TORCH_CHECK(size_val == split_factor + 7); } else if (tensor_name == 3) { - TORCH_CHECK(rhs_value == 4); + TORCH_CHECK(size_val == split_factor + 4); } } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = shift(t0, {1, -1}); @@ -1605,7 +1588,7 @@ TEST(NVFuserTest, FusionShiftChain4_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencilChain_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1625,7 +1608,8 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { tv_stencil1 = add(tv_stencil1, tv); } - tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1)); + tv_stencil1 = div( + tv_stencil1, IrBuilder::create(tv_stencil1_shifts.size() + 1)); // Second stencil: Same 5pt stencil std::vector tv_stencil2_shifts; @@ -1638,7 +1622,8 @@ 
TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { tv_stencil2 = add(tv_stencil2, tv); } - tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1)); + tv_stencil2 = div( + tv_stencil2, IrBuilder::create(tv_stencil2_shifts.size() + 1)); auto tv_out = tv_stencil2; @@ -1682,41 +1667,34 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { // tv0_cache: (split_factor + 4) * (split_factor + 4) // tv_stencil1: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == tv0_cache->name() || tensor_name == tv_stencil1->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i]); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK(size != nullptr && size->isConst()); if (tensor_name == tv0_cache->name()) { - TORCH_CHECK(rhs_value == 4); + TORCH_CHECK(size->value().value() == split_factor[i] + 4); } else if (tensor_name == tv_stencil1->name()) { - TORCH_CHECK(rhs_value == 2); + TORCH_CHECK(size->value().value() == split_factor[i] + 2); } } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto stencil1 = t0; @@ -1735,13 +1713,13 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { } // Shift a reduced tensor -TEST(NVFuserTest, FusionShiftReduction1_CUDA) { +TEST_F(NVFuserTest, FusionShiftReduction1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -1758,7 +1736,7 @@ TEST(NVFuserTest, FusionShiftReduction1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1770,13 +1748,13 @@ TEST(NVFuserTest, FusionShiftReduction1_CUDA) { } // Parallelized version of FusionShiftReduction1 -TEST(NVFuserTest, FusionShiftReduction2_CUDA) { +TEST_F(NVFuserTest, FusionShiftReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -1799,7 +1777,7 @@ TEST(NVFuserTest, FusionShiftReduction2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1810,13 +1788,13 @@ TEST(NVFuserTest, FusionShiftReduction2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, 
__FILE__); } -TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { +TEST_F(NVFuserTest, FusionShiftRfactor1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -1841,7 +1819,7 @@ TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1852,7 +1830,7 @@ TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftBcast1_CUDA) { +TEST_F(NVFuserTest, FusionShiftBcast1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1877,7 +1855,7 @@ TEST(NVFuserTest, FusionShiftBcast1_CUDA) { std::vector inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t4 = t0.unsqueeze(-1).expand({numel_x, numel_y}) + t1; @@ -1886,7 +1864,7 @@ TEST(NVFuserTest, FusionShiftBcast1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftBcast2_CUDA) { +TEST_F(NVFuserTest, FusionShiftBcast2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1911,7 +1889,7 @@ TEST(NVFuserTest, FusionShiftBcast2_CUDA) { std::vector inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); @@ -1922,7 +1900,7 @@ TEST(NVFuserTest, FusionShiftBcast2_CUDA) { } // Combine ShiftBcast1 and ShiftBcast2 with parallelization -TEST(NVFuserTest, FusionShiftBcast3_CUDA) { +TEST_F(NVFuserTest, FusionShiftBcast3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1959,7 +1937,7 @@ TEST(NVFuserTest, FusionShiftBcast3_CUDA) { std::vector inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); @@ -1972,14 +1950,14 @@ TEST(NVFuserTest, FusionShiftBcast3_CUDA) { } // See issue #893 -TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { +TEST_F(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(2)); auto tv3 = add(tv1, tv2); auto tv4 = shift(tv3, {0, 1}); fusion.addOutput(tv4); @@ -1996,15 +1974,15 @@ TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -2016,14 +1994,14 @@ TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { } // See issue #893. Top-level placement. 
-TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { +TEST_F(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(2)); auto tv3 = add(tv1, tv2); auto tv4 = shift(tv3, {1}); fusion.addOutput(tv4); @@ -2037,14 +2015,14 @@ TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -2055,14 +2033,14 @@ TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { +TEST_F(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -2093,7 +2071,7 @@ TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { // along the Y dimension. The other 10 warps are used to load a 32x10 // tile, and all warps will do coalesced loads. No such optimization // is done in the fuser version. -TEST(NVFuserTest, FusionHdiff_CUDA) { +TEST_F(NVFuserTest, FusionHdiff_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2123,7 +2101,7 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { // T9 = T0 * 4 // T10 = T9 - T8 - auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors); + auto lap = sub(mul(inp, IrBuilder::create(4)), sum_of_neighbors); // T11 = shift(T10) // T12 = T11 - T10 @@ -2133,8 +2111,9 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { // T16 = T15 > 0 // T17 = T16 ? 0 : T12 auto flx_cond = - gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), new Double(0)); - auto flx0 = where(flx_cond, new Double(0), flx); + gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), + IrBuilder::create(0)); + auto flx0 = where(flx_cond, IrBuilder::create(0), flx); // T18 = shift(T10) // T19 = T18 - T10 @@ -2144,9 +2123,10 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { // T22 = T19 * T21 // T23 = T22 > 0 auto fly_cond = - gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), new Double(0)); + gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), + IrBuilder::create(0)); // T24 = T23 ? 
0 : T19 - auto fly0 = where(fly_cond, new Double(0), fly); + auto fly0 = where(fly_cond, IrBuilder::create(0), fly); // T25 = shift(flx0) // T26 = T17 - T25 @@ -2233,9 +2213,6 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { } ///////////////////////////////// - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 101; int numel_y = 99; int numel_z = 10; @@ -2244,7 +2221,11 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options); at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options); std::vector inputs = {inp_at, coeff_at}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto fuser_output = fe.runFusion(inputs)[0]; + // Trim the outer rim std::vector indices{ at::indexing::Slice(0, at::indexing::None), @@ -2273,7 +2254,7 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { } } -TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { +TEST_F(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2303,7 +2284,7 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { // T9 = T0 * 4 // T10 = T9 - T8 - auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors); + auto lap = sub(mul(inp, IrBuilder::create(4)), sum_of_neighbors); // T11 = shift(T10) // T12 = T11 - T10 @@ -2313,8 +2294,9 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { // T16 = T15 > 0 // T17 = T16 ? 0 : T12 auto flx_cond = - gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), new Double(0)); - auto flx0 = where(flx_cond, new Double(0), flx); + gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), + IrBuilder::create(0)); + auto flx0 = where(flx_cond, IrBuilder::create(0), flx); // T18 = shift(T10) // T19 = T18 - T10 @@ -2324,9 +2306,10 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { // T22 = T19 * T21 // T23 = T22 > 0 auto fly_cond = - gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), new Double(0)); + gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), + IrBuilder::create(0)); // T24 = T23 ? 
0 : T19 - auto fly0 = where(fly_cond, new Double(0), fly); + auto fly0 = where(fly_cond, IrBuilder::create(0), fly); // T25 = shift(flx0) // T26 = T17 - T25 @@ -2428,9 +2411,6 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { } ///////////////////////////////// - FusionExecutor fe; - fe.compileFusion(&fusion); - const int halo_extent = 2; const int numel_x = 64 + halo_extent * 2; const int numel_y = 64 + halo_extent * 2; @@ -2440,7 +2420,11 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options); at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options); std::vector inputs = {inp_at, coeff_at}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto fuser_output = fe.runFusion(inputs)[0]; + // Trim the outer rim std::vector indices{ at::indexing::Slice(0, at::indexing::None), @@ -2470,7 +2454,7 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { } // 3x3 max pooling -TEST(NVFuserTest, FusionMaxPooling_CUDA) { +TEST_F(NVFuserTest, FusionMaxPooling_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2537,9 +2521,6 @@ TEST(NVFuserTest, FusionMaxPooling_CUDA) { max_tensor->axis(0)->parallelize(ParallelType::BIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int hw = 50; const int num_channels = 20; const int pooling_window = 3; @@ -2555,6 +2536,8 @@ TEST(NVFuserTest, FusionMaxPooling_CUDA) { aten_inp = at::abs(aten_inp); std::vector inputs = {aten_inp}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = at::max_pool2d( @@ -2563,7 +2546,7 @@ TEST(NVFuserTest, FusionMaxPooling_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGatherPadding1_CUDA) { +TEST_F(NVFuserTest, FusionGather1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2586,13 +2569,13 @@ TEST(NVFuserTest, FusionGatherPadding1_CUDA) { auto ref = gather(t0, window_shape, padding_width); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); TORCH_CHECK(ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionGatherPadding2_CUDA) { +TEST_F(NVFuserTest, FusionGather2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2602,7 +2585,7 @@ TEST(NVFuserTest, FusionGatherPadding2_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width); @@ -2629,7 +2612,7 @@ TEST(NVFuserTest, FusionGatherPadding2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -2639,129 +2622,747 @@ TEST(NVFuserTest, FusionGatherPadding2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConv2DStatic_CUDA) { +TEST_F(NVFuserTest, FusionGather3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Input: [C, H, W] - auto inp = makeSymbolicTensor(3); - fusion.addInput(inp); - - // Weights: [K, C, 3, 3] - auto w = makeSymbolicTensor(4); - fusion.addInput(w); - - // Gather a neighbor tile of [3, 3] with padding size of 1 for each - // side of the spatial dimensions - auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}); - // inp_tile: [C, H, W, 1, 3, 3] - - auto inp_bc = - broadcast(inp_tile, {true, false, false, false, false, false, false}); - auto w_bc = 
broadcast(w, {false, false, true, true, true, false, false}); - - auto inp_times_w = mul(inp_bc, w_bc); - - // Reduce the channel and neighbor tile dimensions - auto out = sum(inp_times_w, {1, 4, 5, 6}); - - fusion.addOutput(out); + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); - //////////////////////////////////// + const std::vector window_shape = {1, 3}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; - // Cache the input and weight tensors - auto inp_cache = inp->cache_after(); + auto tv1 = gather(tv0, window_shape, padding_width); - // Blocking the spatial dimensions - const int block_w = 16; - const int block_h = 4; - // Blocking the channel dimension - const int block_c = 8; + fusion.addOutput(tv1); - out->split(2, block_h); - out->split(4, block_w); - out->reorder({{3, 4}}); - // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + const int s1 = 11; + const int s2 = 13; - out->split(1, block_c); - // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); - auto out_rf = out->rFactor({1, -3, -2, -1}); - // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] - // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); - // Create a [block_x, block_y] tile on smem - inp_cache->computeAt(out, 4); - // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] - inp_cache->setMemoryType(MemoryType::Shared); + auto ref = gather(t0, window_shape, padding_width); + TORCH_CHECK(ref.equal(outputs[0])); +} - // Move Ci forward - out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); - inp_cache->computeAt(out_rf, 5); +TEST_F(NVFuserTest, FusionGather4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); - inp_tile->computeAt(out_rf, -1); - w->computeAt(out_rf, -1); + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); - out->axis(0)->parallelize(ParallelType::BIDx); - out->axis(1)->parallelize(ParallelType::TIDz); - out->axis(4)->parallelize(ParallelType::TIDy); - out->axis(5)->parallelize(ParallelType::TIDx); + const std::vector window_shape = {3, 3}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; - scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + auto tv1 = gather(tv0, window_shape, padding_width); - FusionExecutor fe; - fe.compileFusion(&fusion); + fusion.addOutput(tv1); - const int dim_h = 99; - const int dim_w = 101; - const int dim_c = 10; - const int dim_f = 20; + const int s1 = 11; + const int s2 = 13; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); - at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); - std::vector inputs = {at_inp, at_w}; + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); - auto cg_outputs = fe.runFusion(inputs); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); - at_inp = at_inp.unsqueeze(0); 
// at::conv2d needs the N axis - auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); - at_out = at_out.squeeze(0); // drop the N axis + auto ref = gather(t0, window_shape, padding_width); - testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); + TORCH_CHECK(ref.equal(outputs[0])); } -// Mostly the same as the static conv test, but the shape of the weights, -// 3x3 in this case, is given dynamically -TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { +TEST_F(NVFuserTest, FusionGather5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Input: [C, H, W] - auto inp = makeSymbolicTensor(3); - fusion.addInput(inp); + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); - // Weights: [K, C, S, T] - auto w = makeSymbolicTensor(4); - fusion.addInput(w); + const std::vector window_shape = {3, 3}; + const std::vector> padding_width = {{1, 0}, {0, 1}}; - auto w_h = new Int(); - fusion.addInput(w_h); - auto w_w = new Int(); - fusion.addInput(w_w); + auto tv1 = gather(tv0, window_shape, padding_width); - auto pad_h = new Int(); - fusion.addInput(pad_h); - auto pad_w = new Int(); - fusion.addInput(pad_w); + fusion.addOutput(tv1); - // Gather a neighbor tile of [w_dim_h, w_dim_w] with padding + const int s1 = 11; + const int s2 = 13; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// Conv-like pattern with no padding +TEST_F(NVFuserTest, FusionGather6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {3, 4}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; + + auto tv1 = gather(tv0, window_shape, padding_width); + + fusion.addOutput(tv1); + + // Blocking the spatial dimensions + const int block_x = 16; + const int block_y = 8; + + auto tv0_cache = tv0->cache_after(); + auto out = tv1; + auto out_cache = out->cache_before(); + + out->split(1, block_x); + out->split(0, block_y); + out->reorder({{1, 2}, {2, 1}}); + + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->axis(0)->parallelize(ParallelType::BIDy); + out->axis(1)->parallelize(ParallelType::BIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(3)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + const int s1 = 101; + const int s2 = 99; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// Conv-like pattern with 
irregular padding +TEST_F(NVFuserTest, FusionGather7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {3, 4}; + const std::vector> padding_width = {{0, 2}, {2, 1}}; + + auto tv1 = gather(tv0, window_shape, padding_width); + + fusion.addOutput(tv1); + + // Blocking the spatial dimensions + const int block_x = 16; + const int block_y = 8; + + auto tv0_cache = tv0->cache_after(); + auto out = tv1; + auto out_cache = out->cache_before(); + + out->split(1, block_x); + out->split(0, block_y); + out->reorder({{1, 2}, {2, 1}}); + + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->axis(0)->parallelize(ParallelType::BIDy); + out->axis(1)->parallelize(ParallelType::BIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(3)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + const int s1 = 101; + const int s2 = 99; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// With no padding but with striding +TEST_F(NVFuserTest, FusionGather8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {2, 3}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; + const std::vector strides = {3, 3}; + + auto tv1 = gather(tv0, window_shape, padding_width, strides); + + fusion.addOutput(tv1); + + const int s1 = 11; + const int s2 = 13; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + for (const auto i : c10::irange(size.size())) { + size[i] = ceilDiv( + size[i] - window_shape[i] + 1 + padding_width[i][0] + + padding_width[i][1], + strides[i]); + } + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width, strides); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// Similar to Gather8 but with splitting and parallelization +TEST_F(NVFuserTest, FusionGather9_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {3, 4}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; + const std::vector strides = {2, 2}; + + auto tv1 = gather(tv0, window_shape, padding_width, strides); + + fusion.addOutput(tv1); + + // Blocking the spatial dimensions + const int block_x = 16; + const int block_y = 8; + + auto tv0_cache = tv0->cache_after(); + auto out = tv1; + auto out_cache = out->cache_before(); + + out->split(1, block_x); + out->split(0, block_y); + out->reorder({{1, 2}, {2, 1}}); + + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + 
tv0_cache->setMemoryType(MemoryType::Shared); + + out->axis(0)->parallelize(ParallelType::BIDy); + out->axis(1)->parallelize(ParallelType::BIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(3)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + const int s1 = 101; + const int s2 = 99; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + for (const auto i : c10::irange(size.size())) { + size[i] = ceilDiv( + size[i] - window_shape[i] + 1 + padding_width[i][0] + + padding_width[i][1], + strides[i]); + } + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width, strides); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +TEST_F(NVFuserTest, FusionConv2D_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 3, 3] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [3, 3] with padding size of 1 for each + // side of the spatial dimensions + auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}); + // inp_tile: [C, H, W, 1, 3, 3] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); + + //////////////////////////////////// + + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; + + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); + + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + 
fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv2DNoPadding_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 3, 3] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [3, 3] with no padding + auto inp_tile = + gather(inp, {1, 3, 3}, {{0, 0}, {0, 0}, {0, 0}}, {1, 1, 1}, true); + // inp_tile: [C, H-2, W-2, 1, 3, 3] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); + + //////////////////////////////////// + + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; + + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); + + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + std::vector stride = {1, 1}; + std::vector padding = {0, 0}; + auto at_out = at::conv2d(at_inp, at_w, {}, stride, padding); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv2DNoPaddingStrided_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 3, 3] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [2, 2] with no padding and strides of + // [2, 2] + auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 0}, {0, 0}}, {1, 2, 2}); + // 
inp_tile: [C, H/2, W/2, 1, 2, 2] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); + + //////////////////////////////////// + + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; + + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); + + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + std::vector stride = {2, 2}; + std::vector padding = {0, 0}; + auto at_out = at::conv2d(at_inp, at_w, {}, stride, padding); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +// 5x5 followed by 3x3 +TEST_F(NVFuserTest, FusionConv2DChain_CUDA) { + const int dim_w1_h = 5; + const int dim_w1_w = 5; + const int dim_pad1_h = (dim_w1_h - 1) / 2; + const int dim_pad1_w = (dim_w1_w - 1) / 2; + const int dim_w2_h = 3; + const int dim_w2_w = 3; + const int dim_pad2_h = (dim_w2_h - 1) / 2; + const int dim_pad2_w = (dim_w2_w - 1) / 2; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [K1, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K2, K1, S1, T1] + auto w1 = makeSymbolicTensor(4); + fusion.addInput(w1); + + // Weights: [K3, K2, S2, T2] + auto w2 = makeSymbolicTensor(4); + fusion.addInput(w2); + + // Gather a neighbor tile of [w1_h, w1_w] with padding auto inp_tile = gather( inp, - {new Int(1), w_h, w_w}, - {{new Int(0), new Int(0)}, {pad_h, pad_h}, {pad_w, pad_w}}); - // inp_tile: [C, 1, H - w_h + 1, W - w_w + 1, w_h, w_w] + {1, dim_w1_h, dim_w1_w}, + {{0, 0}, {dim_pad1_h, dim_pad1_h}, {dim_pad1_w, dim_pad1_w}}); + // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w1_bc = broadcast(w1, 
{false, false, true, true, true, false, false}); + + auto inp_times_w1 = mul(inp_bc, w1_bc); + + // Reduce the channel and neighbor tile dimensions + auto out1 = sum(inp_times_w1, {1, 4, 5, 6}); + + // Second conv + auto out1_tile = gather( + out1, + {1, dim_w2_h, dim_w2_w}, + {{0, 0}, {dim_pad2_h, dim_pad2_h}, {dim_pad2_w, dim_pad2_w}}); + + auto out1_bc = + broadcast(out1_tile, {true, false, false, false, false, false, false}); + auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false}); + + auto out1_times_w2 = mul(out1_bc, w2_bc); + + auto out2 = sum(out1_times_w2, {1, 4, 5, 6}); + + fusion.addOutput(out2); + + //////////////////////////////////// + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + + out2->split(2, block_h); + out2->split(4, block_w); + out2->reorder({{3, 4}}); + // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out2, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out1->reorder({{5, 3}, {3, 4}, {4, 5}}); + out1->setMemoryType(MemoryType::Shared); + + inp_cache->computeAt(out1, 4); + + inp_tile->computeAt(out1, -1); + w1->computeAt(out1, -1); + + out1_tile->computeAt(out2, -1); + w2->computeAt(out2, -1); + + out2->axis(0)->parallelize(ParallelType::BIDx); + out2->axis(4)->parallelize(ParallelType::TIDy); + out2->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_k1 = 3; + const int dim_k2 = 5; + const int dim_k3 = 7; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options); + at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options); + at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options); + std::vector inputs = {at_inp, at_w1, at_w2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2); + auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1); + at_out2 = at_out2.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 2, 2] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [2, 2] with padding size of 1 only for + // the right side of the spatial dimensions. The left padding is + // zero so that the output axis stays the same. 
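+  // (Illustrative arithmetic, using the output-extent formula from the
+  // Gather8/Gather9 tests above: extent - window + 1 + pad_left + pad_right.
+  // With a 2-wide window and {0, 1} padding, H - 2 + 1 + 0 + 1 == H, so the
+  // fuser output keeps the H x W extents, while the at::conv2d reference,
+  // which pads both sides, produces (H+1) x (W+1) and is trimmed below.)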
+ auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}}); + // inp_tile: [C, H, W, 1, 2, 2] auto inp_bc = broadcast(inp_tile, {true, false, false, false, false, false, false}); @@ -2775,6 +3376,7 @@ TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { fusion.addOutput(out); //////////////////////////////////// + // Cache the input and weight tensors auto inp_cache = inp->cache_after(); @@ -2787,13 +3389,13 @@ TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { out->split(2, block_h); out->split(4, block_w); out->reorder({{3, 4}}); - // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2] out->split(1, block_c); - // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] auto out_rf = out->rFactor({1, -3, -2, -1}); - // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] // out_rf: [K, Ci, Ho, Wo, Hi, Wi] // Create a [block_x, block_y] tile on smem @@ -2815,185 +3417,226 @@ TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 99; const int dim_w = 101; const int dim_c = 10; const int dim_f = 20; - const int dim_w_h = 3; - const int dim_w_w = 3; - const int dim_pad_h = (dim_w_h - 1) / 2; - const int dim_pad_w = (dim_w_w - 1) / 2; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); - at::Tensor at_w = at::randn({dim_f, dim_c, dim_w_h, dim_w_w}, options); - std::vector inputs = { - at_inp, at_w, dim_w_h, dim_w_w, dim_pad_h, dim_pad_w}; + at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); + std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); at_out = at_out.squeeze(0); // drop the N axis + // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas + // the fuser output has dim_h*dim_w. Drop the first elements to make + // it match with the fuser output. + std::vector indices{ + at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(1, at::indexing::None), + at::indexing::Slice(1, at::indexing::None)}; + at_out = at_out.index(indices); testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -// 5x5 followed by 3x3 -TEST(NVFuserTest, FusionConv2DDynamicChain_CUDA) { +TEST_F(NVFuserTest, FusionConv4x4Pad1x1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Input: [K1, H, W] + // Input: [C, H, W] auto inp = makeSymbolicTensor(3); fusion.addInput(inp); - // Weights: [K2, K1, S1, T1] - auto w1 = makeSymbolicTensor(4); - fusion.addInput(w1); + // Weights: [K, C, 4, 4] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); - // Weights: [K3, K2, S2, T2] - auto w2 = makeSymbolicTensor(4); - fusion.addInput(w2); + // Gather a neighbor tile of [4, 4] with padding size of 1 for both + // sides of the spatial dimensions. The resulting extent is + // decreased by one. 
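+  // (Same output-extent arithmetic as in the Gather8/Gather9 tests:
+  // H - 4 + 1 + 1 + 1 == H - 1, which also matches at::conv2d with a 4x4
+  // kernel and padding of 1, so the reference output below needs no trimming.)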
+ auto inp_tile = + gather(inp, {1, 4, 4}, {{0, 0}, {1, 1}, {1, 1}}, {1, 1, 1}, true); + // inp_tile: [C, H-1, W-1, 1, 4, 4] - auto w1_h = new Int(); - fusion.addInput(w1_h); - auto w1_w = new Int(); - fusion.addInput(w1_w); + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); - auto w2_h = new Int(); - fusion.addInput(w2_h); - auto w2_w = new Int(); - fusion.addInput(w2_w); + auto inp_times_w = mul(inp_bc, w_bc); - auto pad_h1 = new Int(); - fusion.addInput(pad_h1); - auto pad_w1 = new Int(); - fusion.addInput(pad_w1); + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); - auto pad_h2 = new Int(); - fusion.addInput(pad_h2); - auto pad_w2 = new Int(); - fusion.addInput(pad_w2); + fusion.addOutput(out); - // Gather a neighbor tile of [w1_h, w1_w] with padding - auto inp_tile = gather( - inp, - {new Int(1), w1_h, w1_w}, - {{new Int(0), new Int(0)}, {pad_h1, pad_h1}, {pad_w1, pad_w1}}); - // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w] + //////////////////////////////////// - auto inp_bc = - broadcast(inp_tile, {true, false, false, false, false, false, false}); - auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false}); + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); - auto inp_times_w1 = mul(inp_bc, w1_bc); + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; - // Reduce the channel and neighbor tile dimensions - auto out1 = sum(inp_times_w1, {1, 4, 5, 6}); + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 4, 4] - // Second conv - auto out1_tile = gather( - out1, - {new Int(1), w2_h, w2_w}, - {{new Int(0), new Int(0)}, {pad_h2, pad_h2}, {pad_w2, pad_w2}}); + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 4, 4] - auto out1_bc = - broadcast(out1_tile, {true, false, false, false, false, false, false}); - auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false}); + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 4, 4] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] - auto out1_times_w2 = mul(out1_bc, w2_bc); + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); - auto out2 = sum(out1_times_w2, {1, 4, 5, 6}); + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); - fusion.addOutput(out2); + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 4, 4}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = 
at_inp.unsqueeze(0); // at::conv2d needs the N axis + auto at_out = + at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 1, 1); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv4x5Pad1x2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 4, 4] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [4, 5] with padding size of 1 and 2 for + // each side of the spatial dimensions. + auto inp_tile = + gather(inp, {1, 4, 5}, {{0, 0}, {1, 1}, {2, 2}}, {1, 1, 1}, true); + // inp_tile: [C, H-1, W, 1, 4, 5] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); //////////////////////////////////// + // Cache the input and weight tensors auto inp_cache = inp->cache_after(); // Blocking the spatial dimensions const int block_w = 16; const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; - out2->split(2, block_h); - out2->split(4, block_w); - out2->reorder({{3, 4}}); - // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3] + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 4, 5] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 4, 5] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 4, 5] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] // Create a [block_x, block_y] tile on smem - inp_cache->computeAt(out2, 4); + inp_cache->computeAt(out, 4); // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] inp_cache->setMemoryType(MemoryType::Shared); // Move Ci forward - out1->reorder({{5, 3}, {3, 4}, {4, 5}}); - out1->setMemoryType(MemoryType::Shared); - - inp_cache->computeAt(out1, 4); - - inp_tile->computeAt(out1, -1); - w1->computeAt(out1, -1); - - out1_tile->computeAt(out2, -1); - w2->computeAt(out2, -1); + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); - out2->axis(0)->parallelize(ParallelType::BIDx); - out2->axis(4)->parallelize(ParallelType::TIDy); - out2->axis(5)->parallelize(ParallelType::TIDx); + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); - scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1}); + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); const int dim_h = 99; const int dim_w = 101; - const int dim_k1 = 3; - const int dim_k2 = 5; - const int dim_k3 = 7; - const int dim_w1_h = 5; - const int dim_w1_w = 5; - const int dim_pad1_h = (dim_w1_h - 1) / 2; - const int dim_pad1_w = (dim_w1_w - 1) / 2; - const int dim_w2_h = 3; - const int dim_w2_w = 3; - const int dim_pad2_h = (dim_w2_h - 1) / 2; - const int dim_pad2_w = (dim_w2_w - 1) / 2; + const int dim_c = 10; + const int dim_f = 20; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); - at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, 
options); - at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options); - at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options); - std::vector inputs = { - at_inp, - at_w1, - at_w2, - dim_w1_h, - dim_w1_w, - dim_w2_h, - dim_w2_w, - dim_pad1_h, - dim_pad1_w, - dim_pad2_h, - dim_pad2_w}; + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 4, 5}, options); + std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis - auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2); - auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1); - at_out2 = at_out2.squeeze(0); // drop the N axis + auto at_out = + at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 1, {1, 2}); + at_out = at_out.squeeze(0); // drop the N axis - testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { +TEST_F(NVFuserTest, FusionConv4x4Pad1x1Stride4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3001,15 +3644,14 @@ TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { auto inp = makeSymbolicTensor(3); fusion.addInput(inp); - // Weights: [K, C, 2, 2] + // Weights: [K, C, 3, 3] auto w = makeSymbolicTensor(4); fusion.addInput(w); - // Gather a neighbor tile of [2, 2] with padding size of 1 only for - // the right side of the spatial dimensions. The left padding is - // zero so that the output axis stays the same. - auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}}); - // inp_tile: [C, H, W, 1, 2, 2] + // Gather a neighbor tile of [4, 4] with padding size of 1 for both + // sides of the spatial dimensions. Set the stride width as 4. 
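+  // (For the 99x101 input used below, the strided extent formula from the
+  // Gather8/Gather9 tests gives ceilDiv(99 - 4 + 1 + 1 + 1, 4) == 25 and
+  // ceilDiv(101 - 4 + 1 + 1 + 1, 4) == 25, matching at::conv2d with a 4x4
+  // kernel, padding 1 and stride 4.)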
+ auto inp_tile = gather(inp, {1, 4, 4}, {{0, 0}, {1, 1}, {1, 1}}, {1, 4, 4}); + // inp_tile: [C, H/4, s4, W/4, s4, 1, 4, 4] auto inp_bc = broadcast(inp_tile, {true, false, false, false, false, false, false}); @@ -3030,43 +3672,49 @@ TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { // Blocking the spatial dimensions const int block_w = 16; const int block_h = 4; - // Blocking the channel dimension - const int block_c = 8; + const int block_c = 2; + // [K, C, H/s, W/s, 1, 4, 4] out->split(2, block_h); + // [K, C, H/s/block_h, block_h, W/s, 1, 4, 4] out->split(4, block_w); + // [K, C, H/s/block_h, block_h, W/s/block_w, block_w, 1, 4, 4] out->reorder({{3, 4}}); - // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2] - + // [K, C, H/s/block_h, W/s/block_w, block_h, block_w, 1, 4, 4] out->split(1, block_c); - // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] + // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, block_h, block_w, 1, 4, + // 4] + out->split(4, 1); + // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1, + // 4, 4] auto out_rf = out->rFactor({1, -3, -2, -1}); - // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] - // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1, + // 4, 4] - // Create a [block_x, block_y] tile on smem - inp_cache->computeAt(out, 4); - // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + // out: [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w] + + inp_cache->computeAt(out, 5); inp_cache->setMemoryType(MemoryType::Shared); + // [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, C/block_c, 1, + // 4, 4] - // Move Ci forward - out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); - inp_cache->computeAt(out_rf, 5); + // Move C/block_c before block_h/2 and share the domain from + // inp_cache to out_rf + out_rf->reorder({{7, 5}, {5, 6}, {6, 7}}); + inp_cache->computeAt(out_rf, 6); inp_tile->computeAt(out_rf, -1); w->computeAt(out_rf, -1); out->axis(0)->parallelize(ParallelType::BIDx); out->axis(1)->parallelize(ParallelType::TIDz); - out->axis(4)->parallelize(ParallelType::TIDy); - out->axis(5)->parallelize(ParallelType::TIDx); + out->axis(4)->parallelize(ParallelType::Unswitch); + out->axis(5)->parallelize(ParallelType::TIDy); + out->axis(6)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 99; const int dim_w = 101; const int dim_c = 10; @@ -3075,28 +3723,23 @@ TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); - at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 4, 4}, options); std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis - auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); + auto at_out = + at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 4, {1, 1}); at_out = at_out.squeeze(0); // drop the N axis - // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas - // the fuser output has dim_h*dim_w. Drop the first elements to make - // it match with the fuser output. 
- std::vector indices{ - at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(1, at::indexing::None), - at::indexing::Slice(1, at::indexing::None)}; - at_out = at_out.index(indices); testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } // POC implementation of im2col for 3-by-3 kernels -TEST(NVFuserTest, FusionIm2Col_CUDA) { +TEST_F(NVFuserTest, FusionIm2Col_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3147,9 +3790,6 @@ TEST(NVFuserTest, FusionIm2Col_CUDA) { scheduler_utils::parallelizeAllLike(out, {inp_cache, inp_tile}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 31; const int dim_w = 33; const int dim_c = 5; @@ -3160,6 +3800,8 @@ TEST(NVFuserTest, FusionIm2Col_CUDA) { at::Tensor at_inp = at::randn({dim_n, dim_c, dim_h, dim_w}, options); std::vector inputs = {at_inp}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); auto at_out = at::im2col(at_inp, {3, 3}, {1, 1}, {1, 1}, {1, 1}); @@ -3171,14 +3813,14 @@ TEST(NVFuserTest, FusionIm2Col_CUDA) { testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPadding1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3201,9 +3843,6 @@ TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { tv5->axis(-2)->parallelize(ParallelType::TIDy); scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3211,6 +3850,9 @@ TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -3226,14 +3868,14 @@ TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { } // Split and merge -TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPadding2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3256,9 +3898,6 @@ TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { tv5->axis(-1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3266,6 +3905,9 @@ TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -3281,14 +3923,14 @@ TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { } // Split and merge, then welford -TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPadding3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, 
false); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3316,9 +3958,6 @@ TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { tv_avg->axis(-1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv_avg, ir_utils::allTvs(&fusion)); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3327,7 +3966,11 @@ TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); + outputs[1] /= (numel_x - 2) * (numel_y - 2); auto t1 = t0 + 1; @@ -3346,13 +3989,13 @@ TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { } // Shift indexing and predication with contiguous merge -TEST(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, true); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3366,15 +4009,15 @@ TEST(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { tv2->setMemoryType(MemoryType::Global); tv3->setMemoryType(MemoryType::Global); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 9; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3392,14 +4035,14 @@ TEST(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { testValidate(&fusion, {fuser_out}, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv2, {1, -1}, false); auto tv4 = sum(tv3, {0, 1}); @@ -3422,9 +4065,6 @@ TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3}); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3433,6 +4073,9 @@ TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -3447,14 +4090,14 @@ TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { } // Rfactor is not allowed with partial domains -TEST(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = sum(tv2, {0, 1}); fusion.addOutput(tv3); @@ -3466,7 +4109,61 @@ TEST(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { ASSERT_ANY_THROW(tv3->rFactor({-2})); } -TEST(NVFuserTest, FusionPartialSplit1_CUDA) { +TEST_F(NVFuserTest, FusionShiftPadding1_CUDA) { + Fusion fusion; + 
FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = shift(tv1, {2, -2}, {1, 1}); + auto tv3 = shift(tv1, {-3, 2}, {2, 2}); + auto tv4 = add(tv2, tv3); + auto tv5 = sum(tv4, {0, 1}); + + fusion.addOutput(tv5); + + tv1->setMemoryType(MemoryType::Shared); + + tv5->split(0, 4); + tv5->split(-1, 8); + tv5->reorder({{1, 2}}); + + TransformPropagator::from(tv5); + + tv2->computeAt(tv5, -1); + tv3->computeAt(tv5, -1); + + tv5->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-2)->parallelize(ParallelType::TIDy); + scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); + + int numel_x = 99; + int numel_y = 101; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto t1 = t0 + 1; + auto t2 = shift(t1, {2, -2}); + auto t3 = shift(t1, {-3, 2}); + auto t4 = t2 + t3; + std::vector indices{ + at::indexing::Slice(1, -1), at::indexing::Slice(0, -1)}; + t4 = t4.index(indices); + auto ref = t4.sum(at::ArrayRef{0, 1}); + + testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPartialSplit1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3474,7 +4171,7 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { // [I] fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(0)); + auto tv1 = add(tv0, IrBuilder::create(0)); // [I] auto tv2 = shift(tv1, {1}, false); // [1:I] @@ -3504,9 +4201,6 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - // gridDim.x is ceilDiv(numel_x - 2, 8), not ceilDiv(numel_x, 8), // so it's going to be just 2 rather than 3. 
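// With numel_x = 18 below, that is ceilDiv(16, 8) = 2 blocks for the
// partial split versus ceilDiv(18, 8) = 3 for a full split.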
const int numel_x = 18; @@ -3527,6 +4221,9 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{at::indexing::Slice(1, -1)}; @@ -3538,21 +4235,21 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPartialSplit2_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(0)); + auto tv1 = add(tv0, IrBuilder::create(0)); auto tv2 = shift(tv1, {1}, false); auto tv3 = shift(tv1, {-1}, false); auto tv4 = add(tv2, tv3); fusion.addOutput(tv4); - auto tv5 = add(tv1, new Double(1)); - auto tv6 = add(tv5, new Double(1)); + auto tv5 = add(tv1, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); fusion.addOutput(tv6); tv4->split(0, 4, true, true); @@ -3568,14 +4265,14 @@ TEST(NVFuserTest, FusionPartialSplit2_CUDA) { } // 2D version of PartialSplit1 -TEST(NVFuserTest, FusionPartialSplit3_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(0)); + auto tv1 = add(tv0, IrBuilder::create(0)); auto tv2 = shift(tv1, {1, 2}, false); auto tv3 = shift(tv1, {-2, -1}, false); auto tv4 = add(tv2, tv3); @@ -3595,9 +4292,6 @@ TEST(NVFuserTest, FusionPartialSplit3_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 32 + 3; const int numel_y = 32 + 3; @@ -3606,6 +4300,9 @@ TEST(NVFuserTest, FusionPartialSplit3_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3620,7 +4317,7 @@ TEST(NVFuserTest, FusionPartialSplit3_CUDA) { // Almost same fusion with Shift5ptStencilChain but non-padded shift // and partial split. -TEST(NVFuserTest, FusionPartialSplit4_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3641,7 +4338,8 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { tv_stencil1 = add(tv_stencil1, tv); } - tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1)); + tv_stencil1 = div( + tv_stencil1, IrBuilder::create(tv_stencil1_shifts.size() + 1)); // Second stencil: Same 5pt stencil std::vector tv_stencil2_shifts; @@ -3654,7 +4352,8 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { tv_stencil2 = add(tv_stencil2, tv); } - tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1)); + tv_stencil2 = div( + tv_stencil2, IrBuilder::create(tv_stencil2_shifts.size() + 1)); auto tv_out = tv_stencil2; @@ -3696,9 +4395,6 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { tv0_cache->setMemoryType(MemoryType::Shared); tv_stencil1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - // Input matrix size is 68x68, and the output is 64x64. Both // gridDim.x and gridim.y should be ceilDiv(numel - 4, // split_factor), which is 4. 
If full split is used, the grid @@ -3709,6 +4405,9 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3731,7 +4430,7 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPartialSplit5_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3743,7 +4442,7 @@ TEST(NVFuserTest, FusionPartialSplit5_CUDA) { fusion.addInput(tv0); auto tv1 = shift(tv0, {0, 1}, false); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); @@ -3760,12 +4459,12 @@ TEST(NVFuserTest, FusionPartialSplit5_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3779,7 +4478,7 @@ TEST(NVFuserTest, FusionPartialSplit5_CUDA) { testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPartialSplit6_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3788,9 +4487,9 @@ TEST(NVFuserTest, FusionPartialSplit6_CUDA) { auto tv0 = makeConcreteTensor({numel_x}); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1}, false); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -3803,12 +4502,12 @@ TEST(NVFuserTest, FusionPartialSplit6_CUDA) { tv1->setMemoryType(MemoryType::Shared); tv2->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3821,7 +4520,7 @@ TEST(NVFuserTest, FusionPartialSplit6_CUDA) { testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { +TEST_F(NVFuserTest, FusionShiftUnswitch1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3840,7 +4539,7 @@ TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { auto tv4 = shift(tv0, {-2, -2}); fusion.addOutput(tv4); - auto tv5 = add(tv0, new Double(1)); + auto tv5 = add(tv0, IrBuilder::create(1)); auto tv6 = shift(tv5, {0, -1}); fusion.addOutput(tv6); @@ -3862,7 +4561,7 @@ TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = shift(t0, {-1, 0}); @@ -3881,27 +4580,22 @@ TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { TORCH_CHECK(t6.equal(outputs[4])); } -TEST(NVFuserTest, FusionGatherUnswitch1_CUDA) { +TEST_F(NVFuserTest, FusionGatherUnswitch1_CUDA) { + const int tv1_gather = 3; + const int tv1_gather_pad = 1; + const int tv2_gather = 5; + const int tv2_gather_pad = 2; + Fusion fusion; 
FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1_gather_param = new Int(); - fusion.addInput(tv1_gather_param); - auto tv1_gather_pad_param = new Int(); - fusion.addInput(tv1_gather_pad_param); - auto tv1 = gather( - tv0, {tv1_gather_param}, {{tv1_gather_pad_param, tv1_gather_pad_param}}); + auto tv1 = gather(tv0, {tv1_gather}, {{tv1_gather_pad, tv1_gather_pad}}); fusion.addOutput(tv1); - auto tv2_gather_param = new Int(); - fusion.addInput(tv2_gather_param); - auto tv2_gather_pad_param = new Int(); - fusion.addInput(tv2_gather_pad_param); - auto tv2 = gather( - tv0, {tv2_gather_param}, {{tv2_gather_pad_param, tv2_gather_pad_param}}); + auto tv2 = gather(tv0, {tv2_gather}, {{tv2_gather_pad, tv2_gather_pad}}); fusion.addOutput(tv2); // Static gather @@ -3923,18 +4617,13 @@ TEST(NVFuserTest, FusionGatherUnswitch1_CUDA) { tv4->axis(1)->parallelize(ParallelType::TIDx); const int numel_x = 100; - const int tv1_gather = 3; - const int tv1_gather_pad = 1; - const int tv2_gather = 5; - const int tv2_gather_pad = 2; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); - std::vector inputs = { - t0, tv1_gather, tv1_gather_pad, tv2_gather, tv2_gather_pad}; + std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = gather(t0, {tv1_gather}, {{tv1_gather_pad, tv1_gather_pad}}); @@ -3950,7 +4639,7 @@ TEST(NVFuserTest, FusionGatherUnswitch1_CUDA) { TORCH_CHECK(t4.equal(outputs[3])); } -TEST(NVFuserTest, FusionGatherStrided1_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3973,7 +4662,7 @@ TEST(NVFuserTest, FusionGatherStrided1_CUDA) { at::Tensor t0 = at::randn({s1, s2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // tv1 has a stride dimension, so its number of dimensions should be @@ -4013,7 +4702,7 @@ TEST(NVFuserTest, FusionGatherStrided1_CUDA) { } // Split strided domain -TEST(NVFuserTest, FusionGatherStrided2_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4024,7 +4713,7 @@ TEST(NVFuserTest, FusionGatherStrided2_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4054,7 +4743,7 @@ TEST(NVFuserTest, FusionGatherStrided2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4065,7 +4754,7 @@ TEST(NVFuserTest, FusionGatherStrided2_CUDA) { } // Outer split -TEST(NVFuserTest, FusionGatherStrided3_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4076,7 +4765,7 @@ TEST(NVFuserTest, FusionGatherStrided3_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4102,7 +4791,7 @@ TEST(NVFuserTest, FusionGatherStrided3_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4112,7 +4801,7 @@ TEST(NVFuserTest, 
FusionGatherStrided3_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGatherStrided4_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4123,7 +4812,7 @@ TEST(NVFuserTest, FusionGatherStrided4_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); // Test propagation of split from one gather output to another auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4147,7 +4836,7 @@ TEST(NVFuserTest, FusionGatherStrided4_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4158,7 +4847,7 @@ TEST(NVFuserTest, FusionGatherStrided4_CUDA) { } // Same as GatherStrided1 but with stride != window -TEST(NVFuserTest, FusionGatherStrided5_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4181,7 +4870,7 @@ TEST(NVFuserTest, FusionGatherStrided5_CUDA) { at::Tensor t0 = at::randn({s1, s2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); auto ref = gather(t0, window_shape, padding_width, strides); @@ -4190,7 +4879,7 @@ TEST(NVFuserTest, FusionGatherStrided5_CUDA) { } // Same as GatherStrided2 but with stride != window -TEST(NVFuserTest, FusionGatherStrided6_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4201,7 +4890,7 @@ TEST(NVFuserTest, FusionGatherStrided6_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4231,7 +4920,7 @@ TEST(NVFuserTest, FusionGatherStrided6_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4242,7 +4931,7 @@ TEST(NVFuserTest, FusionGatherStrided6_CUDA) { } // Same as GatherStrided4 but different strides -TEST(NVFuserTest, FusionGatherStrided7_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4252,7 +4941,7 @@ TEST(NVFuserTest, FusionGatherStrided7_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); // Use different strides auto tv2 = gather(tv1, window_shape, padding_width, {3}); @@ -4271,7 +4960,7 @@ TEST(NVFuserTest, FusionGatherStrided7_CUDA) { } // Same as GatherStrided2 but with unswitch -TEST(NVFuserTest, FusionGatherStrided8_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided8_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4282,7 +4971,7 @@ TEST(NVFuserTest, FusionGatherStrided8_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4316,7 +5005,7 @@ TEST(NVFuserTest, FusionGatherStrided8_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4327,7 +5016,7 @@ TEST(NVFuserTest, FusionGatherStrided8_CUDA) { } // Chained strided gather. Not supported yet. 
-TEST(NVFuserTest, FusionGatherStridedChain_CUDA) { +TEST_F(NVFuserTest, FusionGatherStridedChain_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4339,7 +5028,7 @@ TEST(NVFuserTest, FusionGatherStridedChain_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); // Reduce gathered window @@ -4356,10 +5045,7 @@ TEST(NVFuserTest, FusionGatherStridedChain_CUDA) { ASSERT_ANY_THROW(GpuLower gpulw(&fusion)); } -TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionMaxPoolingStrided_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4379,7 +5065,7 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { auto max_tensor = reductionOp( BinaryOpType::Max, {-3, -2, -1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), inp_tile); fusion.addOutput(max_tensor); @@ -4410,9 +5096,6 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { inp_cache->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int hw = 50; const int num_channels = 20; const int pooling_window = 3; @@ -4428,6 +5111,8 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { aten_inp = at::abs(aten_inp); std::vector inputs = {aten_inp}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = at::max_pool2d( @@ -4436,10 +5121,7 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionConv2DStaticStrided_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4518,9 +5200,6 @@ TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 99; const int dim_w = 101; const int dim_c = 10; @@ -4532,6 +5211,8 @@ TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis @@ -4541,14 +5222,14 @@ TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {-1}); fusion.addOutput(tv2); @@ -4564,7 +5245,7 @@ TEST(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = shift((t0 + 1), {-1}); @@ -4572,7 +5253,7 @@ TEST(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4621,7 
+5302,7 @@ TEST(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { at::Tensor t0 = at::randn({111, 222}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto t1 = gather(t0, {3, 3}, {{1, 1}, {1, 1}}); @@ -4632,6 +5313,59 @@ TEST(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {t4}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionGather9ptStencilDoubleBuffering_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = gather(tv0, {3, 3}, {{1, 1}, {1, 1}}); + auto tv2 = sum(tv1, {-2, -1}); + auto tv3 = div(tv2, IrBuilder::create(9)); + + auto out = tv3; + + fusion.addOutput(out); + + auto tv0_cache = tv0->cache_after(); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->split(-2, 4); + out->split(-1, 32); + out->reorder({{1, 2}, {2, 1}}); + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + out->axis(3)->parallelize(ParallelType::TIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(0)->parallelize(ParallelType::BIDx); + + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + tv0_cache->doubleBuffer(); + + int numel_x = 99; + int numel_y = 101; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto t1 = gather(t0, {3, 3}, {{1, 1}, {1, 1}}); + auto t2 = sum(t1, {-2, -1}); + auto t3 = t2 / 9; + auto ref = t3; + + testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/test/cpp/jit/test_gpu_validator.h b/test/cpp/jit/test_gpu_validator.h index 5923e384e39..4b01f361cfc 100644 --- a/test/cpp/jit/test_gpu_validator.h +++ b/test/cpp/jit/test_gpu_validator.h @@ -4,6 +4,7 @@ #include #include +#include #include namespace torch { @@ -11,6 +12,25 @@ namespace jit { namespace fuser { namespace cuda { +inline bool deviceMajorMinorCheck(int major, int minor = 0) { + auto dev_prop = at::cuda::getDeviceProperties(0); + if (dev_prop->major < major || + (dev_prop->major == major && dev_prop->minor < minor)) { + return false; + } + return true; +} + +class NVFuserTest : public ::testing::Test { + protected: + void SetUp() override { + // requires PASCAL or newer + if (!deviceMajorMinorCheck(6)) { + GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs"; + } + } +}; + struct ValidationConstants { // Tolerances generated from randn + add + sum fusion // compared against double precision @@ -66,8 +86,8 @@ std::pair getTolerance( } else { // Reduction case size_t entry = 0; - while (sum_tolerance_entry[entry][0] < reduction_size && - entry < sum_tolerance_entry.size()) { + while (entry < sum_tolerance_entry.size() && + sum_tolerance_entry[entry][0] < reduction_size) { entry++; } double abs_tol = 0.0; @@ -221,7 +241,7 @@ class ReductionSizeMapper : private IterVisitor { } void handle(Expr* expr) override { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return; } diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 1d78a19c5ad..299c738c570 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -3,6 +3,10 @@ import unittest import os import random +import enum +import copy +from functools import reduce 
+import operator import torch from torch.nn import functional @@ -20,6 +24,8 @@ import numpy as np import math +from torch.autograd.gradcheck import gradcheck + from typing import List CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')) @@ -465,21 +471,25 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) def _unary_test_helper(self, operation, dtype, random_data): - shape = (4, 8, 32, 32) + gradient_check = (dtype == torch.float64) and random_data + shape = (8, 7) + torch.cuda.manual_seed_all(211) # need additional def of t for boolean ops def t(x: torch.Tensor, y: torch.Tensor): o = x * y + o = o + 5e-3 o = operation(o) return o - y = torch.tensor([1], device="cuda").to(dtype) + y = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check) + y = y.to(dtype=dtype) if random_data: - x = torch.randn(shape, dtype=torch.float32, device="cuda") + x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check) if dtype in self.int_types: # prefer a larger variance for integer types - x *= 5 + x = x * 5 x = x.to(dtype=dtype) else: x = self.special_values.to(dtype=dtype) @@ -491,14 +501,14 @@ def t(x: torch.Tensor, y: torch.Tensor): t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) - if dtype in self.support_tensor_dtypes: + jit_o = t_jit(x, y) + if gradient_check: + gradcheck(t_jit, [x, y], nondet_tol=1e-5) + elif dtype in self.support_tensor_dtypes: self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) - self.assertEqual(o, jit_o, msg=f""" - failing case: - {dtype} {operation} {x} - """) + self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2)) @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -651,6 +661,16 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = o + z return o + def t_int(x: torch.Tensor, y: torch.Tensor): + o = operation(x, y) + o = 2 + o + return o + + def t_float(x: torch.Tensor, y: torch.Tensor): + o = operation(x, y) + o = 2. 
+ o + return o + shape = (4, 32, 32) if random_data: x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1) @@ -665,14 +685,16 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): if operation in div_like and (dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64): y[y == 0] = 1 - o = t(x, y, z) - t_jit = torch.jit.script(t) - jit_o = t_jit(x, y, z) - jit_o = t_jit(x, y, z) + for test_fn in [t, t_int, t_float]: + o = t(x, y, z) + t_jit = torch.jit.script(t) + jit_o = t_jit(x, y, z) + jit_o = t_jit(x, y, z) + jit_o = t_jit(x, y, z) - self.assertEqual(o.dtype, jit_o.dtype) - self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) + self.assertEqual(o.dtype, jit_o.dtype) + self.assertEqual(o, jit_o) + self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -887,6 +909,21 @@ def test_ternary_ops_type_promotion(self): self._ternary_test_helper(op, dtypes, True) # random data self._ternary_test_helper(op, dtypes, False) # special numbers + # We can't test the scalar version of rsub from python + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") + def test_rsub(self): + x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") + y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") + + def rsub(x: torch.Tensor, y: torch.Tensor): + o = torch.rsub(x, y) + o = o * 2. + return o + + rsub_jit = torch.jit.script(rsub) + self._run_helper(rsub_jit, rsub, x, y) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") # legacy fuser does not work for rand_like, see issue #34361 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -1008,6 +1045,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_random_topo(self): os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1" self.assertTrue(runDefaultTestWithSeed(28449)) @@ -1272,7 +1311,6 @@ def forward(self, x: torch.Tensor): self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error)) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) - @unittest.skipIf(True, "codegen failure awaiting fix") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -1287,7 +1325,6 @@ def test_native_layer_norm(self): norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)] self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine) - @unittest.skipIf(True, "codegen failure awaiting fix") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -2306,7 +2343,7 @@ def t(x: torch.Tensor, y: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 - o = torch.matmul(o, y) + o = torch._C._nn.linear(o, y) return o x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True) @@ -2380,7 +2417,7 @@ def t(x: torch.Tensor, y: torch.Tensor): o = 
x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 - o = torch.matmul(o, y) + o = torch._C._nn.linear(o, y) return o x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True) @@ -2731,8 +2768,8 @@ def __init__(self, num_features=10, affine=True, track_running_stats=True): track_running_stats=track_running_stats).to(dtype=dtype) def forward(self, x): - o = x * 2.0 - o = self.bn(o) + o = self.bn(x) + o = o * 2.0 return o x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_() @@ -3055,6 +3092,7 @@ def _run_fwd_helper(self, func, ops, *args): for op in ops: self.assertGraphContainsExactly(graph, op, 0) + @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -3068,7 +3106,7 @@ def t(x: torch.Tensor): o1 = x + 1.0 o2 = x * 0.5 return o1, o2 - self._run_fwd_helper(t, ['aten::add'], x) + self._run_fwd_helper(t, ['aten::add', 'aten::mul'], x) def t2(x: torch.Tensor, y: torch.Tensor): o1 = x.sum(0) @@ -3076,7 +3114,6 @@ def t2(x: torch.Tensor, y: torch.Tensor): return o1, o2 self._run_fwd_helper(t2, ['aten::sum', 'aten::mul'], x, y) - @unittest.skipIf(True, "Fixed in PR #68804") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -3120,6 +3157,343 @@ def t(x: torch.Tensor, y: torch.Tensor): graph = jitted.graph_for(x, y) self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) + def _bias_view_relu_helper(self, shape, output_shape, dtype, device, error): + class BiasViewRelu(torch.nn.Module): + def __init__(self): + super(BiasViewRelu, self).__init__() + self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) + with torch.no_grad(): + self.bias.fill_(10) + + def forward(self, inputs : torch.Tensor, view_shape : List[int]): + o = inputs + self.bias + o = o.view(view_shape) + return torch.relu(o) + + t = BiasViewRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + # profiling + jit_o = t_jit(x, output_shape) + # optimization + jit_o = t_jit(x, output_shape) + # final + jit_o = t_jit(x, output_shape) + # eager - baseline + o = t(x, output_shape) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x, output_shape) + + has_inferred_dimension = any([dim == -1 for dim in output_shape]) + if has_inferred_dimension: + # prohibit fusing when view_shape contains an inferred dimension + self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) + self.assertGraphContainsExactly(graph, 'prim::view_copy', 0) + else: + self.assertGraphContains(graph, FUSION_GUARD) + self.assertGraphContains(graph, 'prim::view_copy', True) + + def _alias_bias_view_relu_helper(self, shape, output_shape, dtype, device, error): + class BiasViewRelu(torch.nn.Module): + def __init__(self): + super(BiasViewRelu, self).__init__() + self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) + with torch.no_grad(): + self.bias.fill_(10) + + def forward(self, inputs : torch.Tensor, view_shape : List[int]): + o = inputs.view(view_shape) + inputs = inputs * self.bias + return torch.relu(o) + + t = BiasViewRelu() + x = torch.randn(shape, dtype=dtype, 
device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + # profiling + jit_o = t_jit(x, output_shape) + # optimization + jit_o = t_jit(x, output_shape) + # final + jit_o = t_jit(x, output_shape) + # eager - baseline + o = t(x, output_shape) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x, output_shape) + self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) + self.assertGraphContainsExactly(graph, 'prim::view_copy', 0) + + # generate random view given original view + def _random_view(self, original_view, max_len=8, max_views=10000): + class Moves(enum.Enum): + Merge = 0 + Split = 1 + Broadcast = 2 + ImplicitBroadcast = 3 + Keep = 4 + + def valid(old_view, new_view): + old_view_size = reduce(operator.mul, old_view) + new_view_size = reduce(operator.mul, new_view) + return old_view_size == new_view_size + + # given a random starting number, find the nearest divisor + def find_nearest_divisor(N): + if 2 >= (N - 1): + return -1 + result = random.randint(2, N - 1) + while (N % result) != 0: + result += 1 + return result + + complete_views = set([tuple(original_view)]) + + to_visit = [] + # empty new view, curent originaal view, start pos=0, move count = 0, last_move + to_visit.append(([], original_view, 0, [], Moves.Keep)) + + # depth-first search of view shapes, starting from the original view + while len(to_visit) > 0 and len(complete_views) < max_views: + new_view, old_view, odx, move_list, last_move = to_visit[-1] + to_visit.pop() + + # iterate over each move type + for idx in range(len(Moves)): + state = Moves(idx) + new_view_clone = copy.deepcopy(new_view) + old_view_clone = copy.deepcopy(old_view) + new_move_list = move_list + [state] + new_odx = odx + + # Update state using Move state + if state == Moves.Keep: + new_size = old_view_clone[odx] + new_view_clone.append(new_size) + new_odx += 1 + + elif state == Moves.Merge: + if odx + 1 < len(old_view_clone): + new_size = old_view_clone[odx] * old_view_clone[odx + 1] + new_view_clone.append(new_size) + new_odx += 2 + else: + continue + + elif state == Moves.Broadcast and last_move != Moves.Broadcast: + new_view_clone.append(1) + + elif state == Moves.Split: + new_size = find_nearest_divisor(old_view_clone[odx]) + if new_size == -1: + continue + new_view_clone.append(new_size) + old_view_clone[odx] = int(old_view[odx] / new_size) + + if old_view_clone[odx] == 1: + new_odx += 1 + + elif state == Moves.ImplicitBroadcast: + old_view_clone.insert(odx + 1, 1) + new_size = old_view[odx] * 1 + new_view_clone.append(new_size) + new_odx += 2 + + if new_odx < len(old_view_clone) and len(new_move_list) < max_len: + to_visit.append((new_view_clone, old_view_clone, new_odx, new_move_list, state)) + elif (valid(original_view, new_view_clone)): + final_new_view = tuple(new_view_clone) + complete_views.add(final_new_view) + return list(complete_views) + + # ndims - number of dimensions + # test_fn - view test function + def _view_test_generator(self, ndims, test_fn): + # create random tensor + # max value for each dimension + max_size = 10e7 + max_value = max(int(pow(max_size, 1. 
/ ndims)), 1) + sizes = [random.randint(1, max_value) for idx in range(ndims)] + x = torch.randn(sizes) + + original_sizes = list(x.size()) + all_views = self._random_view(original_sizes) + random.shuffle(all_views) + + max_samples = 20 + max_views = min(len(all_views), max_samples) + total = 0 + correct = 0 + # test random combinations of compatible views + for idx in range(max_views): + for jdx in range(idx + 1, max_views): + total += 1 + test_fn(all_views[idx], all_views[jdx], torch.float, 'cuda', 1e-6) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_view(self): + torch._C._jit_set_nvfuser_guard_mode(True) + self._bias_view_relu_helper([2, 3, 4, 5], [-1, 4, 5], torch.float, 'cuda', 1e-6) + for ndims in range(1, 5): + self._view_test_generator(ndims, self._bias_view_relu_helper) + self._alias_bias_view_relu_helper([2, 3, 4, 5], [1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) + + def _bias_squeeze_relu_helper(self, shape, dtype, device, error): + class BiasSqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasSqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = inputs + bias + o = torch.squeeze(o) + return torch.relu(o) + + t = BiasSqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x) + self.assertGraphContains(graph, FUSION_GUARD) + self.assertGraphContains(graph, 'prim::squeeze_copy', True) + + def _alias_bias_squeeze_relu_helper(self, shape, dtype, device, error): + class BiasSqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasSqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = torch.squeeze(inputs) + inputs = inputs * bias + return torch.relu(o) + + t = BiasSqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x, bias) + self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) + self.assertGraphContainsExactly(graph, 'prim::squeeze_copy', 0) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_squeeze(self): + self._bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) + self._alias_bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) + + def _bias_unsqueeze_relu_helper(self, shape, dtype, device, error): + class BiasUnsqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasUnsqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = inputs + bias + o = torch.unsqueeze(o, 0) + return torch.relu(o) + + t = BiasUnsqueezeRelu() + x = torch.randn(shape, dtype=dtype, 
device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x) + self.assertGraphContains(graph, FUSION_GUARD) + self.assertGraphContains(graph, 'prim::unsqueeze_copy', True) + + def _alias_bias_unsqueeze_relu_helper(self, shape, dtype, device, error): + class BiasUnsqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasUnsqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = torch.squeeze(inputs) + o = torch.unsqueeze(inputs, 0) + inputs = inputs * bias + return torch.relu(o) + + t = BiasUnsqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x) + self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) + self.assertGraphContainsExactly(graph, 'prim::unsqueeze_copy', 0) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_unsqueeze(self): + self._bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6) + self._alias_bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_alias_pass_fix(self): + x = torch.randn(4, 24, 2, 2, dtype=torch.float, device="cuda") + w = torch.randn(24, 24, 1, 1, dtype=torch.float, device="cuda") + b = torch.randn(24, dtype=torch.float, device="cuda") + + def t(x, w, b): + b2 = b + 1.0 + o = torch.conv2d(x, w, b2) + return o + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, w, b) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_squeeze_negative_dim(self): + x = torch.randn(4, 24, 1, 2, dtype=torch.float, device="cuda") + + def t(x): + o = x + 1.0 + o = o.squeeze(-2) + o = o * 2.0 + return o + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -3154,6 +3528,138 @@ def t(x, y, s): # sibling fusion should be disabled with the flag self.assertGraphContainsExactly(t_jit.graph_for(x, y, s), FUSION_GUARD, 0) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_build_shape_expression_native_dropout(self): + x = torch.randn(4, 2, device="cuda") + + def t(x): + o, mask = torch.native_dropout(x, 0.0, True) + o1 = o.sigmoid() + o2 = mask.float().sigmoid() + return (o1, o2) + + t_jit = torch.jit.script(t) + + jit_o = t_jit(x) + jit_o = t_jit(x) + o = t(x) + for oo, 
jit_oo in zip(o, jit_o): + self.assertEqual(oo.dtype, jit_oo.dtype) + self.assertEqual(oo, jit_oo) + self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_scalar_tensor_permuted(self): + x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0]) + y = torch.tensor(1.0, device="cuda") + + with nvfuser_singleton_fusion(True): + def t(x, y): + return x + y + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_cpu_scalar(self): + x = torch.randn(4, 2, 3, device="cuda") + y = torch.tensor(1.0, device="cpu") + z = torch.tensor(2.0, device="cpu") + + with nvfuser_singleton_fusion(True): + # testing cpu scalar tensor promotion + def t(x, y): + return x + y + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y) + + # scalar cpu tensor add should NOT be fused + @torch.jit.script + def t1(y, z): + return y * z + for _ in range(5): + t1(y, z) + self.assertGraphContainsExactly(t1.graph_for(y, z), FUSION_GUARD, 0) + + # everything, including scalar cpu tensor add should be fused + @torch.jit.script + def t2(x, y, z): + tmp = y + z + return tmp + x + for _ in range(5): + t2(x, y, z) + self.assertGraphContainsExactly(t2.graph_for(x, y, z), 'aten::add', 0) + self.assertGraphContainsExactly(t2.graph_for(x, y, z), FUSION_GUARD, 1) + + # 'cpu_tmp = y + z' shouldn't be fused. + @torch.jit.script + def t3(x, y, z): + cpu_tmp = y + z + out = x + y + return cpu_tmp, out + for _ in range(5): + t3(x, y, z) + self.assertGraphContainsExactly(t3.graph_for(x, y, z), FUSION_GUARD, 1) + self.assertGraphContainsExactly(t3.graph_for(x, y, z), 'aten::add', 1) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_shape_expression(self): + x = torch.randn(4, 2, 1, 3, device="cuda") + + def t_unsqueeze(x): + t0 = x.relu() + t1 = t0.unsqueeze(1) + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + def t_squeeze(x): + t0 = x.relu() + t1 = t0.squeeze() + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + def t_squeeze_dim(x): + t0 = x.relu() + t1 = t0.squeeze(-2) + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + # squeezing a non-size 1 dimension should be a no op + def t_squeeze_dim_no_op(x): + t0 = x.relu() + t1 = t0.squeeze(1) + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + def run(fn): + jit_fn = torch.jit.script(fn) + jit_o = jit_fn(x) + jit_o = jit_fn(x) + jit_o = jit_fn(x) + o = fn(x) + # output 0 is a tensor, so we check dtype and value + self.assertEqual(o[0].dtype, jit_o[0].dtype) + self.assertEqual(o[0], jit_o[0]) + # output 1 is shape + self.assertEqual(o[1], jit_o[1]) + self.assertGraphContainsExactly(jit_fn.graph_for(x), FUSION_GUARD, 1) + + for t in [t_unsqueeze, t_squeeze, t_squeeze_dim, t_squeeze_dim_no_op]: + run(t) + class TestPassManagerCudaFuser(JitTestCase): @unittest.skipIf(not RUN_CUDA, "requires CUDA") diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index da78a7ceb4b..f63e4ea1668 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -635,7 +635,9 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp", 
"torch/csrc/jit/codegen/cuda/instrumentation.cpp", "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp", + "torch/csrc/jit/codegen/cuda/ir_builder.cpp", "torch/csrc/jit/codegen/cuda/ir_cloner.cpp", + "torch/csrc/jit/codegen/cuda/ir_container.cpp", "torch/csrc/jit/codegen/cuda/ir_graphviz.cpp", "torch/csrc/jit/codegen/cuda/ir_nodes.cpp", "torch/csrc/jit/codegen/cuda/ir_iostream.cpp", @@ -645,28 +647,32 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/kernel_cache.cpp", "torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp", "torch/csrc/jit/codegen/cuda/kernel_ir.cpp", - "torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp", - "torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp", + "torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp", "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp", - "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp", "torch/csrc/jit/codegen/cuda/lower_allocation.cpp", + "torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp", "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp", + "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp", "torch/csrc/jit/codegen/cuda/lower_index.cpp", "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp", "torch/csrc/jit/codegen/cuda/lower_loops.cpp", "torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp", "torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp", "torch/csrc/jit/codegen/cuda/lower_predicate.cpp", + "torch/csrc/jit/codegen/cuda/lower_replace_size.cpp", "torch/csrc/jit/codegen/cuda/lower_shift.cpp", "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp", + "torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp", "torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp", "torch/csrc/jit/codegen/cuda/lower_unroll.cpp", "torch/csrc/jit/codegen/cuda/lower_utils.cpp", "torch/csrc/jit/codegen/cuda/lower_validation.cpp", + "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp", "torch/csrc/jit/codegen/cuda/lower2device.cpp", "torch/csrc/jit/codegen/cuda/manager.cpp", "torch/csrc/jit/codegen/cuda/mutator.cpp", "torch/csrc/jit/codegen/cuda/non_divisible_split.cpp", + "torch/csrc/jit/codegen/cuda/ops/alias.cpp", "torch/csrc/jit/codegen/cuda/ops/composite.cpp", "torch/csrc/jit/codegen/cuda/ops/normalization.cpp", "torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp", diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 2c9925cf893..d9bf46b51c7 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -23,14 +24,15 @@ Val* newScalar(ValType vtype, DataType dtype) { case (ValType::Scalar): switch (dtype) { case DataType::Bool: - return new Bool(); + return IrBuilder::create(); case DataType::Double: case DataType::Float: case DataType::Half: case DataType::BFloat16: - return new Double(); + return IrBuilder::create(); + case DataType::Int32: case DataType::Int: - return new Int(); + return IrBuilder::create(); default: break; } @@ -103,10 +105,10 @@ TensorView* newOutputTV(const std::vector& vals, DataType dtype) { } for (const auto dim_i : c10::irange(out_domain.size())) { if (extent_vals[dim_i] != nullptr) { - out_domain[dim_i] = new IterDomain( - new Int(start_offsets[dim_i]), + out_domain[dim_i] = IrBuilder::create( + IrBuilder::create(start_offsets[dim_i]), extent_vals[dim_i], - new Int(stop_offsets[dim_i]), + IrBuilder::create(stop_offsets[dim_i]), ParallelType::Serial, iter_types[dim_i]); } else { @@ -121,13 +123,17 @@ TensorView* newOutputTV(const 
std::vector& vals, DataType dtype) { break; } } - out_domain[dim_i] = - new IterDomain(new Int(0), new Int(1), ParallelType::Serial, itype); + out_domain[dim_i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), + ParallelType::Serial, + itype); } } - return new TensorView( - new TensorDomain(out_domain, std::vector(out_domain.size(), true)), + return IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), dtype); } @@ -195,7 +201,7 @@ Val* castOp(DataType dtype, Val* v1) { } Val* out = newValLike(v1, dtype); - new UnaryOp(UnaryOpType::Cast, out, v1); + IrBuilder::create(UnaryOpType::Cast, out, v1); return out; } @@ -219,7 +225,7 @@ Val* unaryOp(UnaryOpType type, Val* v1) { // } Val* out = newValLike(v1, v1->getDataType().value()); - new UnaryOp(type, out, v1); + IrBuilder::create(type, out, v1); return out; } @@ -379,7 +385,7 @@ Val* binaryOp(BinaryOpType type, Val* v1, Val* v2, DataType common_dtype) { } else { out = newScalar(out_vtype, out_dtype); } - new BinaryOp(type, out, vals[0], vals[1]); + IrBuilder::create(type, out, vals[0], vals[1]); return out; } @@ -589,7 +595,7 @@ static TensorView* newForReduction( " of tensor ", tv); - new_domain.push_back(new IterDomain( + new_domain.push_back(IrBuilder::create( id->start(), id->extent(), id->stopOffset(), @@ -597,12 +603,12 @@ static TensorView* newForReduction( isReduction ? IterType::Reduction : id->getIterType())); } - TensorDomain* td = - new TensorDomain(new_domain, std::vector(new_domain.size(), true)); + TensorDomain* td = IrBuilder::create( + new_domain, std::vector(new_domain.size(), true)); data_type = data_type == DataType::Null ? tv->getDataType().value() : data_type; - return new TensorView(td, data_type); + return IrBuilder::create(td, data_type); } TensorView* reductionOp( @@ -652,7 +658,7 @@ TensorView* reductionOp( out_type, " and ", init_type); - new ReductionOp(reduction_op_type, init, out, tv); + IrBuilder::create(reduction_op_type, init, out, tv); if (keep_dim) { auto tv_root = TensorDomain::noReductions(tv->getRootDomain()); @@ -673,9 +679,9 @@ TensorView* sum( Val* init = nullptr; auto dtype = v1->getDataType().value(); if (isFloatingPointType(dtype)) { - init = new Double(0.0); + init = IrBuilder::create(0.0); } else if (isIntegralType(dtype)) { - init = new Int(0); + init = FusionGuard::getCurFusion()->zeroVal(); } else { TORCH_CHECK( false, @@ -693,13 +699,13 @@ TensorView* max( Val* init = nullptr; switch (v1->getDataType().value()) { case (DataType::Double): - init = new Double(std::numeric_limits::lowest()); + init = IrBuilder::create(std::numeric_limits::lowest()); break; case (DataType::Float): - init = new Double(std::numeric_limits::lowest()); + init = IrBuilder::create(std::numeric_limits::lowest()); break; case (DataType::Int): - init = new Int(INT_MIN); + init = IrBuilder::create(INT_MIN); break; default: TORCH_CHECK( @@ -718,13 +724,13 @@ TensorView* min( Val* init = nullptr; switch (v1->getDataType().value()) { case (DataType::Double): - init = new Double(DBL_MAX); + init = IrBuilder::create(DBL_MAX); break; case (DataType::Float): - init = new Double(FLT_MAX); + init = IrBuilder::create(FLT_MAX); break; case (DataType::Int): - init = new Int(INT_MAX); + init = IrBuilder::create(INT_MAX); break; default: TORCH_CHECK( @@ -767,9 +773,9 @@ TensorView* broadcast( size_t iinp = 0, ibdim = 0; while (ibdim < is_broadcast_dim.size()) { if (is_broadcast_dim[ibdim]) { - out_domain.push_back(new IterDomain( - new Int(0), 
- new Int(1), + out_domain.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), ParallelType::Serial, IterType::BroadcastWithoutStride)); } else { @@ -779,10 +785,11 @@ TensorView* broadcast( ibdim++; } - TensorView* out_tensor = new TensorView( - new TensorDomain(out_domain, std::vector(out_domain.size(), true)), + TensorView* out_tensor = IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), inp->getDataType().value()); - new BroadcastOp(out_tensor, inp, is_broadcast_dim); + IrBuilder::create(out_tensor, inp, is_broadcast_dim); return out_tensor; } @@ -799,6 +806,10 @@ WelfordResult Welford( TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor"); TORCH_CHECK(axes.size() > 0, "No reduction axis specified"); + if (init_N == nullptr) { + init_N = FusionGuard::getCurFusion()->zeroVal(); + } + // Initial values for welford op are tensors, so their dims have to match the // output dim, // i.e. original_dims - dims_to_be_reduced @@ -819,8 +830,8 @@ WelfordResult Welford( init_avg_val = init_avg; init_var_val = init_var; } else { - init_avg_val = new Double(0); - init_var_val = new Double(0); + init_avg_val = IrBuilder::create(0); + init_var_val = IrBuilder::create(0); } // Check and collect reduction axes @@ -847,7 +858,7 @@ WelfordResult Welford( TensorView* out_var = newForReduction(tv, uint_axes); TensorView* out_N = newForReduction(tv, uint_axes, DataType::Int); - new WelfordOp( + IrBuilder::create( out_avg, out_var, out_N, /*out var/avg/count */ @@ -856,7 +867,7 @@ WelfordResult Welford( init_N, /*init var/avg/count */ tv, nullptr, - new Int(1)); /*in var/avg/count */ + FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */ return WelfordResult(out_avg, out_var, out_N); } @@ -888,10 +899,11 @@ TensorView* transpose( out_domain[i] = in_id->clone(); } - TensorView* out_tensor = new TensorView( - new TensorDomain(out_domain, std::vector(out_domain.size(), true)), + TensorView* out_tensor = IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), inp->getDataType().value()); - new TransposeOp(out_tensor, inp, new2old); + IrBuilder::create(out_tensor, inp, new2old); return out_tensor; } @@ -938,7 +950,7 @@ TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); } // lerp -TORCH_CUDA_CU_API Val* lerp(Val* start, Val* end, Val* weight) { +Val* lerp(Val* start, Val* end, Val* weight) { auto vals = maybeBroadcast({start, end, weight}); Val* intrm1 = sub(vals[1], vals[0]); Val* intrm2 = mul(vals[2], intrm1); @@ -1024,7 +1036,8 @@ Val* where(Val* c, Val* v1, Val* v2) { } else { out = newScalar(out_vtype, out_dtype); } - new TernaryOp(TernaryOpType::Where, out, vals[0], vals[1], vals[2]); + IrBuilder::create( + TernaryOpType::Where, out, vals[0], vals[1], vals[2]); return out; } @@ -1064,7 +1077,8 @@ Val* threshold(Val* in, Val* thresh, Val* value) { value = optionalCast(in->getDataType().value(), value); Val* out = newValLike(in, in->getDataType().value()); - new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value); + IrBuilder::create( + TernaryOpType::Threshold, out, in, thresh, value); return out; } @@ -1084,7 +1098,7 @@ Val* clamp(Val* in, Val* min_val, Val* max_val) { max_val = optionalCast(in->getDataType().value(), max_val); Val* out = newValLike(in, in->getDataType().value()); - new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val); + IrBuilder::create(TernaryOpType::Clamp, 
out, in, min_val, max_val); return out; } @@ -1186,125 +1200,157 @@ TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { } TensorView* shift(TensorView* inp, const std::vector& offsets, bool pad) { + // When pad is false, no padding is given. When it is true, padding + // sizes are set so that output domains have the same extents as + // input domains. + std::vector pad_width(offsets.size(), 0); + if (pad) { + for (const auto i : c10::irange(offsets.size())) { + pad_width[i] = std::abs(offsets[i]); + } + } + return shift(inp, offsets, pad_width); +} + +TensorView* shift( + TensorView* inp, + const std::vector& offsets, + const std::vector& pad_width_param) { + auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); + const auto ndims = inp_dom.size(); + + auto pad_width = pad_width_param; + // Default padding is set so that the extent is kept unchanged + if (pad_width.empty()) { + pad_width = offsets; + for (auto& p : pad_width) { + p = std::abs(p); + } + } + TORCH_CHECK( - TensorDomain::noReductions(inp->getRootDomain()).size() == offsets.size(), + ndims == offsets.size(), "Invalid shift offsets, number of entries in offsets expected to be ", - TensorDomain::noReductions(inp->getRootDomain()).size(), + ndims, " but received ", offsets.size()); + TORCH_CHECK( + ndims == pad_width.size(), + "Invalid padding width list, number of entries in pad_width expected to be ", + ndims, + " but received ", + pad_width.size()); + + std::for_each(pad_width.begin(), pad_width.end(), [](const auto& pad) { + TORCH_CHECK(pad >= 0, "Padding width must be >= 0: ", pad); + }); + TensorView* out = nullptr; - if (pad) { - out = newValLike(inp, inp->getDataType().value())->as(); - } else { - auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); - const auto ndims = inp_dom.size(); - std::vector out_dom; - for (const auto i : c10::irange(ndims)) { - const auto inp_axis = inp_dom[i]; - const auto offset = offsets[i]; - if (offset == 0) { - out_dom.push_back(inp_axis->clone()); - continue; - } + std::vector out_dom; + for (const auto i : c10::irange(ndims)) { + const auto inp_axis = inp_dom[i]; + const auto offset = offsets[i]; + const auto pad = pad_width[i]; - Int* current_start_offset = dynamic_cast(inp_axis->start()); - TORCH_INTERNAL_ASSERT( - current_start_offset != nullptr && current_start_offset->isConst(), - "Invalid IterDomain start value:", - current_start_offset); + if (offset == 0) { + out_dom.push_back(inp_axis->clone()); + continue; + } - Int* current_stop_offset = dynamic_cast(inp_axis->stopOffset()); - TORCH_INTERNAL_ASSERT( - current_stop_offset != nullptr && current_stop_offset->isConst(), - "Invalid IterDomain stop offset value:", - current_stop_offset); - - const auto cur_start_offset_value = current_start_offset->value().value(); - const auto cur_stop_offset_value = current_stop_offset->value().value(); - - Val* out_start_offset = nullptr; - Val* out_stop_offset = nullptr; - - if (offset > 0) { - // shift to right; extent remains the same, start and stop - // positions are moved right - out_start_offset = new Int(cur_start_offset_value + offset); - out_stop_offset = - new Int(std::max(cur_stop_offset_value - offset, int64_t(0))); - } else { - // shift to left; extent remains the same, start and stop - // positions are moved left - out_start_offset = - new Int(std::max(cur_start_offset_value + offset, int64_t(0))); - out_stop_offset = new Int(cur_stop_offset_value - offset); - } + Int* current_start_offset = dynamic_cast(inp_axis->start()); + 
TORCH_INTERNAL_ASSERT( + current_start_offset != nullptr && current_start_offset->isConst(), + "Invalid IterDomain start value:", + current_start_offset); - out_dom.push_back(new IterDomain( - out_start_offset, - inp_axis->extent(), - out_stop_offset, - ParallelType::Serial, - inp_axis->getIterType())); + Int* current_stop_offset = dynamic_cast(inp_axis->stopOffset()); + TORCH_INTERNAL_ASSERT( + current_stop_offset != nullptr && current_stop_offset->isConst(), + "Invalid IterDomain stop offset value:", + current_stop_offset); + + const auto cur_start_offset_value = current_start_offset->value().value(); + const auto cur_stop_offset_value = current_stop_offset->value().value(); + + int64_t out_start_offset = 0; + int64_t out_stop_offset = 0; + + if (offset > 0) { + // shift to right; extent remains the same, start and stop + // positions are moved right + out_start_offset = cur_start_offset_value + offset - pad; + out_stop_offset = std::max(cur_stop_offset_value - offset, int64_t(0)); + // If pad > offset, the extent of the output ID could be larger than the + // input, and the start offset of the output domain could become + // negative, which is not supported. + TORCH_CHECK( + out_start_offset >= 0, + "Invalid shift offset and padding. Padding must not be larger than the absolute extent of shift offset. Padding: ", + pad, + ". Shift: ", + offset, + "."); + } else { + // shift to left; extent remains the same, start and stop + // positions are moved left + out_start_offset = std::max(cur_start_offset_value + offset, int64_t(0)); + out_stop_offset = cur_stop_offset_value - offset - pad; + // Similar to the above case where offset is positive, if pad > + // -offset (note offset is negative), the extent of the output + // ID could be larger than the input, and the stop offset of the + // output domain could become negative. + TORCH_CHECK( + out_stop_offset >= 0, + "Invalid shift offset and padding. Padding must not be larger than the absolute extent of shift offset. Padding: ", + pad, + ". Shift: ", + offset, + "."); } - out = new TensorView( - new TensorDomain(out_dom, std::vector(out_dom.size(), true)), - inp->getDataType().value()); + out_dom.push_back(IrBuilder::create( + IrBuilder::create(out_start_offset), + inp_axis->extent(), + IrBuilder::create(out_stop_offset), + ParallelType::Serial, + inp_axis->getIterType())); } - new ShiftOp(out, inp, offsets, pad); - return out; -} - -namespace { -std::vector convertToIntVector(const std::vector& x) { - std::vector converted; - std::transform(x.begin(), x.end(), std::back_inserter(converted), [](int x) { - return new Int(x); - }); - return converted; -} -} // namespace + out = IrBuilder::create( + IrBuilder::create( + out_dom, std::vector(out_dom.size(), true)), + inp->getDataType().value()); -TensorView* gather( - TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides) { - std::vector window_shape_int = convertToIntVector(window_shape); - std::vector> pad_width_int; - std::transform( - pad_width.begin(), - pad_width.end(), - std::back_inserter(pad_width_int), - [](const std::vector& x) { return convertToIntVector(x); }); - return gather(inp, window_shape_int, pad_width_int, strides); + IrBuilder::create(out, inp, offsets, pad_width); + return out; } namespace { -// Return a new TensorDomain with given root domains. Apply strides if -// necessary. With non-unit strides, strided domains become an rfactor -// domain. +// Return a new TensorDomain with given root domains.
Apply +// strides if necessary. With non-unit strides, strided domains become an +// rfactor domain. TensorDomain* generateTensorDomainWithStrides( const std::vector& root_domains, - const std::vector& strides) { + const std::vector& strides, + bool skip_unit_stride) { std::vector strided_domains; // If strides are just unit strides, don't apply striding - if (strides.empty() || std::all_of(strides.begin(), strides.end(), [](int s) { - return s == 1; - })) { - return new TensorDomain( + if (strides.empty() || + (skip_unit_stride && + std::all_of( + strides.begin(), strides.end(), [](int s) { return s == 1; }))) { + return IrBuilder::create( root_domains, std::vector(root_domains.size(), true)); } for (const auto i : c10::irange(root_domains.size())) { auto root_dom = root_domains.at(i); - if (i >= strides.size() || strides[i] == 1) { + if (i >= strides.size() || (skip_unit_stride && strides[i] == 1)) { strided_domains.push_back(root_dom); continue; } @@ -1317,7 +1363,7 @@ TensorDomain* generateTensorDomainWithStrides( auto contig_vector_size = strided_domains.size(); - auto strided_td = new TensorDomain( + auto strided_td = IrBuilder::create( root_domains, strided_domains, strided_domains, @@ -1330,9 +1376,10 @@ TensorDomain* generateTensorDomainWithStrides( TensorView* gather( TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides) { + const std::vector& window_shape, + const std::vector>& pad_width, + const std::vector& strides, + bool trim_out_of_bounds) { auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); const auto ndims = inp_dom.size(); @@ -1343,6 +1390,10 @@ TensorView* gather( " but received ", window_shape.size()); + std::for_each(window_shape.begin(), window_shape.end(), [](const auto& w) { + TORCH_CHECK(w > 0, "Window size must be > 0: ", w); + }); + TORCH_CHECK( ndims == pad_width.size(), "Invalid pad width: number of entries expected to be ", @@ -1354,6 +1405,10 @@ TensorView* gather( TORCH_CHECK( p.size() == 2, "Each entry of pad_width must have two non-negative integers."); + std::for_each(p.begin(), p.end(), [](const auto& p_left_or_right) { + TORCH_CHECK( + p_left_or_right >= 0, "Padding must be >= 0: ", p_left_or_right); + }); }); TORCH_CHECK( @@ -1363,6 +1418,10 @@ TensorView* gather( " but received ", strides.size()); + std::for_each(strides.begin(), strides.end(), [](const auto& s) { + TORCH_CHECK(s > 0, "Stride must be > 0: ", s); + }); + std::vector out_root_domains; std::vector out_gather_dom; @@ -1371,40 +1430,57 @@ TensorView* gather( const auto window_dim = window_shape[i]; const auto pad_left = pad_width[i][0]; const auto pad_right = pad_width[i][1]; + // This may be over-conservative TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt()); + const auto inp_stop_offset = inp_axis->stopOffset()->getInt(); + TORCH_INTERNAL_ASSERT( + inp_stop_offset.has_value(), + "Dynamic stop offset not supported: ", + inp_axis); + const auto extent_adjustment = window_dim - 1 - pad_left - pad_right; + TORCH_CHECK( + extent_adjustment >= 0, + "Invalid gather window and padding as output extent would be larger than input.", + " Window: ", + window_dim, + ". Padding left: ", + pad_left, + ". 
Padding right: ", + pad_right); + const auto out_stop_offset = inp_stop_offset.value() + extent_adjustment; Val* out_axis_dim = nullptr; - if (window_dim->isConst() && pad_left->isConst() && pad_right->isConst()) { - const int64_t extent_adjustment = - -(-window_dim->value().value() + 1 + pad_left->value().value() + - pad_right->value().value()); - out_axis_dim = extent_adjustment == 0 - ? inp_axis->extent() - : sub(inp_axis->extent(), new Int(extent_adjustment)); - } else { - out_axis_dim = - add(add(sub(inp_axis->extent(), window_dim), new Int(1)), - add(pad_left, pad_right)); - } - // TODO: out_axis_dim is assumed to be the same as the extent of - // the input domain. Throw an error if it isn't the case. - out_root_domains.push_back(new IterDomain( - new Int(0), - out_axis_dim, + out_root_domains.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + inp_axis->extent(), + IrBuilder::create(out_stop_offset), ParallelType::Serial, inp_axis->getIterType())); // create a new axis for the gathered domain - out_gather_dom.push_back(new IterDomain( - new Int(0), window_dim, ParallelType::Serial, IterType::Gather)); + out_gather_dom.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create(window_dim), + ParallelType::Serial, + IterType::Gather)); } out_root_domains.insert( out_root_domains.end(), out_gather_dom.begin(), out_gather_dom.end()); - auto out_td = generateTensorDomainWithStrides(out_root_domains, strides); + TensorDomain* out_td = nullptr; + + if (trim_out_of_bounds) { + // If no stride vector is given, just use stride 1. It does not do + // any striding effect, but out-of-bounds values are trimmed. + auto s = strides.empty() ? std::vector(ndims, 1) : strides; + out_td = generateTensorDomainWithStrides(out_root_domains, s, false); + } else { + out_td = generateTensorDomainWithStrides(out_root_domains, strides, true); + } - auto out_tv = new TensorView(out_td, inp->getDataType().value()); + auto out_tv = + IrBuilder::create(out_td, inp->getDataType().value()); - new GatherOp(out_tv, inp, window_shape, pad_width); + IrBuilder::create(out_tv, inp, window_shape, pad_width); return out_tv; } diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 5652d68eab8..1f18f65666a 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -114,7 +114,9 @@ TORCH_CUDA_CU_API WelfordResult Welford( const std::vector& axes, TensorView* init_avg = nullptr, TensorView* init_var = nullptr, - Int* init_N = new Int(0)); + // Initializes to 0 in function definition, doing this so we don't have to + // import IrBuilder just for this one interface. + Int* init_N = nullptr); // UNARY OPERATIONS // abs @@ -484,19 +486,27 @@ TORCH_CUDA_CU_API TensorView* sum_to( //! t1[i, j] = 0, otherwise //! //! The pad option controls how out-of-boundary accesses are -//! handled. When pad is true, shifting works as if the source tensor -//! is padded by zero. Otherwise, it does not modify the output tensor -//! region whose source coordinates are out-of-boundry. In both cases, -//! the size of output tensor does not change. However, when pad is -//! false, the start or stop value of the shifted axis is adjusted -//! accordingly. For example, when a shift offset is one, the axis start -//! value would be incremented by one. +//! handled. It specifies how many zeros are logically padded. If no +//!
pad option is given, it automatically pads the input tensor so +//! that the output tensor has the same extent for each axis. //! -//! \param pad If true, out-of-boundary access returns zero. +//! When a padding value is smaller than the absolute value of a shift +//! offset, the output axis still has the same extent but its start or +//! stop offset is moved inward to signify that positions outside of the +//! offset are invalid. +//! +//! It is not allowed to use padding values that are larger than shift +//! offsets, which would mean output extents would be larger than +//! input extents. TORCH_CUDA_CU_API TensorView* shift( TensorView* inp, const std::vector& offsets, - bool pad = true); + const std::vector& pad_width = {}); + +TORCH_CUDA_CU_API TensorView* shift( + TensorView* inp, + const std::vector& offsets, + bool pad); //! Gather a window of nearby elements for each element. //! @@ -508,8 +518,13 @@ TORCH_CUDA_CU_API TensorView* shift( //! implemented with strided split, whose outer output domain becomes //! the root domain for subsequent consumers. The inner output domain //! becomes a Stride domain, which is ignored by subsequent consumers. +//! Only valid input ranges are fed into strided splits. //! -//! Example: +//! When trim_out_of_bounds is true, the values at the first and last +//! ends that are outside of the start and stop offsets are +//! effectively trimmed by partial split by 1. +//! +//! Example 1: //! t0: 2D tensor of [N, M] //! t1 = gather(t0, {1, 3}, {{0, 0}, {1, 1}}); //! @@ -517,23 +532,34 @@ TORCH_CUDA_CU_API TensorView* shift( //! t1: [N, M, 1, 3] //! t1[i, j, k, l] = The value at the window position of [k, l] //! for t0[i, j] -TORCH_CUDA_CU_API TensorView* gather( - TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides = {}); - -//! Gather a window of nearby elements for each element. //! -//! Same as the another gather interface but with Int* parameters. +//! Example 2.1 (without trimming): +//! t0: 2D tensor of [N, M] +//! t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}}); +//! +//! then: +//! t1: [N (stop offset: 1), M (stop offset: 1), 2, 2] +//! +//! Example 2.2 (with trimming): +//! t0: 2D tensor of [N, M] +//! t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}}, {}, true); +//! +//! then: +//! t1: [ceilDiv(N - 1, 1), ceilDiv(M - 1, 1), 2, 2] +//! +//! Example 3: +//! t0: 2D tensor of [N, M] +//! t1 = gather(t0, {3, 3}, {{0, 0}, {0, 0}}, {3, 3}); +//! +//! then: +//! t1: [ceilDiv(N - 2, 3), ceilDiv(M - 2, 3), 3, 3] //! -//! TODO: Remove this interface as we do not intend to support dynamic -//! window shapes at this moment.
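// ---- editor's sketch (illustrative only, not part of this patch) ----
// A standalone model of the extent bookkeeping behind Examples 2 and 3 above:
// with trimming, the valid range of each gathered root axis shrinks by
// (window - 1 - pad_left - pad_right) and is then strided. The function and
// constants below are stand-ins for illustration, not the nvfuser API.
#include <cassert>
#include <iostream>

namespace sketch {

// Integer ceiling division, as in the ceilDiv used by the examples.
long ceilDiv(long a, long b) {
  return (a + b - 1) / b;
}

// Output extent along one gathered axis when trim_out_of_bounds is true.
long gatherOutExtent(
    long in_extent, long window, long pad_left, long pad_right, long stride) {
  const long adjustment = window - 1 - pad_left - pad_right;
  assert(adjustment >= 0 && "padding must not exceed window - 1");
  return ceilDiv(in_extent - adjustment, stride);
}

} // namespace sketch

int main() {
  // Example 2.2: window 2, no padding, stride 1 -> ceilDiv(N - 1, 1)
  std::cout << sketch::gatherOutExtent(10, 2, 0, 0, 1) << "\n"; // prints 9
  // Example 3: window 3, no padding, stride 3 -> ceilDiv(N - 2, 3)
  std::cout << sketch::gatherOutExtent(10, 3, 0, 0, 3) << "\n"; // prints 3
  return 0;
}
// ---- end of editor's sketch ----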
TORCH_CUDA_CU_API TensorView* gather( TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides = {}); + const std::vector& window_shape, + const std::vector>& pad_width, + const std::vector& strides = {}, + bool trim_out_of_bounds = false); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 709c810efe3..67926e92672 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -19,7 +20,7 @@ namespace codegen { namespace { -class CudaKernelGenerator : private kir::IrVisitor { +class CudaKernelGenerator : private OptOutConstDispatch { static constexpr const char* kTab = " "; public: @@ -45,7 +46,7 @@ class CudaKernelGenerator : private kir::IrVisitor { code_ << "__global__ void " << kernel_name << "("; - std::vector params; + std::vector params; // Inputs & Outputs for (auto val : kernel_->inputs()) { @@ -56,13 +57,16 @@ class CudaKernelGenerator : private kir::IrVisitor { } // Generate parameter declarations - for (kir::Val* val : params) { - if (const auto tv = dynamic_cast(val)) { - code_ << "Tensor<" << val->dtype() << ", " - << TensorDomain::noReductions( - tv->fuserTv()->getMaybeRFactorDomain()) - .size() + for (Val* val : params) { + if (const auto tv = dynamic_cast(val)) { + if (tv->isCpuScalar()) { + code_ << " CpuScalarTensor<" << val->dtype() << "> " << varName(tv); + } else { + code_ + << "Tensor<" << val->dtype() << ", " + << TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size() << "> " << varName(tv); + } } else { TORCH_INTERNAL_ASSERT(val->isScalar()); // NOLINT (LLVM bug 48525) TORCH_INTERNAL_ASSERT(val->definition() == nullptr); @@ -76,17 +80,17 @@ class CudaKernelGenerator : private kir::IrVisitor { // Global buffers for (auto allocate : kernel_summary.global_allocations) { - TORCH_INTERNAL_ASSERT(allocate->buffer()->isA()); - const auto tv = allocate->buffer()->as(); + TORCH_INTERNAL_ASSERT(allocate->buffer()->isA()); + const auto tv = allocate->buffer()->as(); const auto& maybe_rfactor_domain = tv->domain()->hasRFactor() - ? tv->domain()->rfactorDomain() - : tv->domain()->rootDomain(); + ? 
tv->domain()->getRFactorDomain() + : tv->domain()->getRootDomain(); const auto nDims = std::count_if( maybe_rfactor_domain.begin(), maybe_rfactor_domain.end(), - [](const kir::IterDomain* id) { + [](const IterDomain* id) { return !id->isReduction() && - id->iterType() != IterType::BroadcastWithoutStride; + id->getIterType() != IterType::BroadcastWithoutStride; }); code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> " << varName(tv); @@ -177,7 +181,7 @@ class CudaKernelGenerator : private kir::IrVisitor { void genBody() { for (auto expr : kernel_->topLevelExprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } @@ -204,100 +208,93 @@ class CudaKernelGenerator : private kir::IrVisitor { return code_; } - std::string gen(const kir::Node* node) { + std::string gen(const Statement* stmt) { std::stringstream tmp_code; std::swap(tmp_code, code_); - auto replacement = replacement_map_.find(node); + auto replacement = replacement_map_.find(stmt); if (replacement != replacement_map_.end()) { - node = replacement->second; + stmt = replacement->second; } - node->accept(this); + OptOutConstDispatch::handle(stmt); std::swap(tmp_code, code_); return tmp_code.str(); } - // TODO(kir): consider automatic var naming - std::string varName(const kir::Val* val) { - std::string prefix = ""; - if (val->isA()) { - prefix = "T"; - } else { - prefix = typePrefix(val->dtype()); - } - - std::stringstream value_name; - if (val->name() != kInvalidStmName) { - value_name << prefix << val->name(); + std::string varName(const Val* val) { + std::stringstream name; + if (val->isA()) { + name << "T"; } else { - value_name << "k" << prefix << val->id(); + name << typePrefix(val->dtype()); } - return value_name.str(); + name << val->name(); + return name.str(); } - std::string genInline(const kir::Node* node) { + std::string genInline(const Statement* stmt) { const bool saved_inline = print_inline_; print_inline_ = true; - auto result = gen(node); + auto result = gen(stmt); print_inline_ = saved_inline; // NOLINTNEXTLINE(performance-no-automatic-move) return result; } - void visit(const kir::Predicate* node) final { - TORCH_INTERNAL_ASSERT(node->hasValue()); - code_ << gen(node->value()); + void handle(const kir::Predicate* pred) final { + TORCH_INTERNAL_ASSERT(pred->hasValue()); + code_ << gen(pred->value()); } - void visit(const kir::Bool* node) final { - const auto def = node->definition(); + void handle(const Bool* pred) final { + const auto def = pred->definition(); if (print_inline_ && def != nullptr) { code_ << "(" << gen(def) << ")"; - } else if (node->isConst()) { - code_ << (*node->value() ? "true" : "false"); + } else if (pred->isConst()) { + code_ << (*pred->value() ? 
"true" : "false"); } else { - code_ << varName(node); + code_ << varName(pred); } } - void visit(const kir::Double* node) final { - const auto def = node->definition(); + void handle(const Double* d) final { + const auto def = d->definition(); if (print_inline_ && def != nullptr) { code_ << "(" << gen(def) << ")"; - } else if (node->isConst()) { + } else if (d->isConst()) { const int digits = std::numeric_limits::max_digits10; - code_ << std::setprecision(digits) << *node->value(); + code_ << std::setprecision(digits) << *d->value(); } else { - code_ << varName(node); + code_ << varName(d); } } - void visit(const kir::Int* node) final { - const auto def = node->definition(); + void handle(const Int* i) final { + const auto def = i->definition(); if (print_inline_ && def != nullptr) { code_ << "(" << gen(def) << ")"; - } else if (node->isConst()) { - code_ << *node->value(); + } else if (i->isConst()) { + code_ << *i->value(); } else { - code_ << varName(node); + code_ << varName(i); } } - void visit(const kir::NamedScalar* node) final { + void handle(const NamedScalar* ns) final { // dim3 components are unsigned int. Cast to signed integer to // support negative indexing - if (node->getParallelIndex().has_value() || - node->getParallelDim().has_value()) { - code_ << "((nvfuser_index_t)" << node->name() << ")"; + if (ns->getParallelIndex().has_value() || + ns->getParallelDim().has_value()) { + code_ << "((nvfuser_index_t)" << ns->name() << ")"; } else { - code_ << node->name(); + code_ << ns->name(); } } - void visit(const kir::TensorIndex* node) final { - code_ << varName(node->view()) << "["; + void handle(const kir::TensorIndex* ti) final { + code_ << varName(ti->view()) << "["; bool first = true; - for (auto* ind : node->indices()) { + for (auto* ind : ti->indices()) { if (!ind->isZeroInt()) { if (!first) { code_ << " + "; @@ -314,29 +311,29 @@ class CudaKernelGenerator : private kir::IrVisitor { code_ << "]"; } - void visit(const kir::IterDomain* node) final { - TORCH_INTERNAL_ASSERT(false && "Unreachable"); + void handle(const IterDomain*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void visit(const kir::TensorDomain* node) final { - TORCH_INTERNAL_ASSERT(false && "Unreachable"); + void handle(const TensorDomain*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void visit(const kir::TensorView* tv) final { - TORCH_INTERNAL_ASSERT(false && "Unreachable"); + void handle(const TensorView*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void visit(const kir::UnaryOp* node) final { + void handle(const UnaryOp* uop) final { bool is_vector_op = false; size_t vector_word_size = 1; - if (vectorize_scope_ && node->out()->isA()) { - auto ti = node->out()->as(); + if (vectorize_scope_ && uop->out()->isA()) { + auto ti = uop->out()->as(); bool vectorize_op = false; bool misaligned_op = false; - for (auto id : ti->view()->fuserTv()->domain()->domain()) { + for (auto id : ti->view()->domain()->domain()) { if (!isParallelTypeVectorize(id->getParallelType())) { continue; } @@ -358,84 +355,84 @@ class CudaKernelGenerator : private kir::IrVisitor { if (vectorize_op) { TORCH_INTERNAL_ASSERT( - node->operation() == UnaryOpType::Set, + uop->getUnaryOpType() == UnaryOpType::Set, "Cannot vectorize operations that are not sets. 
", "Use cache_before and cache_after to store/load with vectorized reads into buffers."); is_vector_op = true; } if (misaligned_op) { - is_vector_op = (node->operation() == UnaryOpType::Set); + is_vector_op = (uop->getUnaryOpType() == UnaryOpType::Set); } - if (is_vector_op && !node->in()->isScalar()) { + if (is_vector_op && !uop->in()->isScalar()) { TORCH_INTERNAL_ASSERT( - node->out()->dtype() == node->in()->dtype(), + uop->out()->dtype() == uop->in()->dtype(), "Vectorized store/load requires input and output datatypes match."); } } if (is_vector_op) { - if (node->in()->isScalar()) { + if (uop->in()->isScalar()) { indent() << "reinterpret_cast<" - << "Array<" << node->out()->dtype() << ", " << vector_word_size + << "Array<" << uop->out()->dtype() << ", " << vector_word_size << ">*>" - << "(&" << gen(node->out()) << ")->set(" << gen(node->in()) + << "(&" << gen(uop->out()) << ")->set(" << gen(uop->in()) << ");\n"; } else { indent() << "*reinterpret_cast<" - << "Array<" << node->out()->dtype() << ", " << vector_word_size + << "Array<" << uop->out()->dtype() << ", " << vector_word_size << ">*>" - << "(&" << gen(node->out()) << ")" + << "(&" << gen(uop->out()) << ")" << " = *reinterpret_cast<" - << "Array<" << node->in()->dtype() << ", " << vector_word_size + << "Array<" << uop->in()->dtype() << ", " << vector_word_size << ">*>" - << "(&" << gen(node->in()) << ");\n"; + << "(&" << gen(uop->in()) << ");\n"; } return; } - if (node->out()->isA()) { - const auto op_type = node->operation(); + if (uop->out()->isA()) { + const auto op_type = uop->getUnaryOpType(); if (auto op = inline_op_str(op_type)) { - indent() << gen(node->out()) << " = " << *op << genInline(node->in()) + indent() << gen(uop->out()) << " = " << *op << genInline(uop->in()) << ";\n"; } return; } if (!print_inline_) { - indent() << gen(node->out()); - if (!node->out()->isScalar() && !node->in()->isScalar()) { + indent() << gen(uop->out()); + if (!uop->out()->isScalar() && !uop->in()->isScalar()) { code_ << "\n"; indent() << kTab; } code_ << " = "; } - const auto op_type = node->operation(); + const auto op_type = uop->getUnaryOpType(); if (auto op = inline_op_str(op_type)) { if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { - code_ << stringifyBooleanOp(op_type) << gen(node->in()); + uop->out()->dtype() == DataType::Bool) { + code_ << stringifyBooleanOp(op_type) << gen(uop->in()); } else { - code_ << *op << gen(node->in()); + code_ << *op << gen(uop->in()); } } else { if (op_type == UnaryOpType::Cast) { const auto cast_str = - cast_func_str({node->in()->dtype(), node->out()->dtype()}); + cast_func_str({uop->in()->dtype(), uop->out()->dtype()}); TORCH_INTERNAL_ASSERT( cast_str.has_value(), "Invalid cast. 
Input type: ", - node->in()->dtype(), + uop->in()->dtype(), ", output type: ", - node->out()->dtype()); + uop->out()->dtype()); code_ << cast_str.value(); } else { code_ << op_type; if (needFloatSuffix(op_type) && - node->out()->dtype() == DataType::Float) { + uop->out()->dtype() == DataType::Float) { code_ << "f"; } } @@ -444,7 +441,7 @@ class CudaKernelGenerator : private kir::IrVisitor { if (op_type == UnaryOpType::RandLike) { code_ << "rnd"; } else { - code_ << gen(node->in()); + code_ << gen(uop->in()); } code_ << ")"; } @@ -456,7 +453,7 @@ class CudaKernelGenerator : private kir::IrVisitor { std::string genBinaryOp( BinaryOpType op_type, - kir::Val* out, + Val* out, const std::string& lhs, const std::string& rhs) { std::stringstream expr; @@ -485,7 +482,7 @@ class CudaKernelGenerator : private kir::IrVisitor { // If one argument is a tensorview and the other is a scalar, make sure we // cast the scalar to the tensorview type - std::string scalarCast(kir::Val* lhs, kir::Val* rhs) { + std::string scalarCast(Val* lhs, Val* rhs) { // If neither are scalars return if (!((lhs->isScalar() || rhs->isScalar()) && (lhs->isA() || rhs->isA()))) { @@ -520,18 +517,18 @@ class CudaKernelGenerator : private kir::IrVisitor { } // If possible, replace pow with mul. Return true when successful. - bool genPowerWithMul(const kir::BinaryOp* node) { - if (node->operation() != BinaryOpType::Pow) { + bool genPowerWithMul(const BinaryOp* bop) { + if (bop->getBinaryOpType() != BinaryOpType::Pow) { return false; } - auto rhs = node->rhs(); + auto rhs = bop->rhs(); c10::optional exponent; - if (auto val_int = dynamic_cast(rhs)) { + if (auto val_int = dynamic_cast(rhs)) { if (val_int->isConst()) { exponent = val_int->value().value(); } - } else if (auto val_float = dynamic_cast(rhs)) { + } else if (auto val_float = dynamic_cast(rhs)) { if (val_float->isConst()) { auto fp_exp = val_float->value().value(); double int_exp = 0; @@ -550,7 +547,7 @@ class CudaKernelGenerator : private kir::IrVisitor { return false; } - auto lhs = gen(node->lhs()); + auto lhs = gen(bop->lhs()); if (print_inline_) { code_ << lhs << " * " << lhs; @@ -558,8 +555,8 @@ class CudaKernelGenerator : private kir::IrVisitor { code_ << " * " << lhs; } } else { - indent() << gen(node->out()); - if (node->out()->isScalar()) { + indent() << gen(bop->out()); + if (bop->out()->isScalar()) { code_ << " = " << lhs << " * " << lhs; if (exponent.value() == 3) { code_ << " * " << lhs; @@ -579,24 +576,24 @@ class CudaKernelGenerator : private kir::IrVisitor { return true; } - void visit(const kir::BinaryOp* node) final { + void handle(const BinaryOp* bop) final { // Try replacing pow with mul - if (genPowerWithMul(node)) { + if (genPowerWithMul(bop)) { return; } - const auto op_type = node->operation(); + const auto op_type = bop->getBinaryOpType(); if (print_inline_) { // Inline expression: `lhs op rhs` code_ << genBinaryOp( - op_type, node->out(), gen(node->lhs()), gen(node->rhs())); + op_type, bop->out(), gen(bop->lhs()), gen(bop->rhs())); } else { - indent() << gen(node->out()); - if (node->out()->isScalar()) { + indent() << gen(bop->out()); + if (bop->out()->isScalar()) { // Single line: `out = lhs op rhs;` code_ << " = " << genBinaryOp( - op_type, node->out(), gen(node->lhs()), gen(node->rhs())); + op_type, bop->out(), gen(bop->lhs()), gen(bop->rhs())); } else { // Split TensorView expressions across multiple lines: // @@ -605,64 +602,64 @@ class CudaKernelGenerator : private kir::IrVisitor { // op rhs; // - auto cast = scalarCast(node->lhs(), 
node->rhs()); + auto cast = scalarCast(bop->lhs(), bop->rhs()); if (auto op = inline_op_str(op_type)) { code_ << "\n"; - indent() << kTab << "= " << (node->lhs()->isScalar() ? cast : "") - << gen(node->lhs()) << "\n"; + indent() << kTab << "= " << (bop->lhs()->isScalar() ? cast : "") + << gen(bop->lhs()) << "\n"; indent() << kTab; if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { + bop->out()->dtype() == DataType::Bool) { code_ << stringifyBooleanOp(op_type); } else { code_ << *op; } - code_ << " " << (node->rhs()->isScalar() ? cast : "") - << gen(node->rhs()); + code_ << " " << (bop->rhs()->isScalar() ? cast : "") + << gen(bop->rhs()); } else { - if (integer_op_str(op_type) && isIntegralType(node->out()->dtype())) { + if (integer_op_str(op_type) && isIntegralType(bop->out()->dtype())) { auto int_op = integer_op_str(op_type); code_ << " = " << *int_op << "(\n"; } else { std::stringstream op_str; op_str << op_type; if (needFloatSuffix(op_type) && - node->out()->dtype() == DataType::Float) { + bop->out()->dtype() == DataType::Float) { op_str << "f"; } code_ << " = " << op_str.str() << "(\n"; } - indent() << kTab << (node->lhs()->isScalar() ? cast : "") - << gen(node->lhs()) << ",\n"; - indent() << kTab << (node->rhs()->isScalar() ? cast : "") - << gen(node->rhs()) << ")"; + indent() << kTab << (bop->lhs()->isScalar() ? cast : "") + << gen(bop->lhs()) << ",\n"; + indent() << kTab << (bop->rhs()->isScalar() ? cast : "") + << gen(bop->rhs()) << ")"; } } code_ << ";\n"; } } - void visit(const kir::TernaryOp* node) final { + void handle(const TernaryOp* top) final { if (!print_inline_) { - indent() << gen(node->out()); - if (!node->out()->isScalar()) { + indent() << gen(top->out()); + if (!top->out()->isScalar()) { code_ << "\n"; indent() << kTab; } code_ << " = "; } - code_ << node->operation() << "(" << gen(node->in1()) << ", "; + code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", "; // Make sure the two operands of where has the same // type. Note that compiling "where(0.0f, 0.0)" fails because of // the overloading ambiguity. - if (node->operation() == TernaryOpType::Where) { - auto cast = scalarCast(node->in2(), node->in3()); - code_ << (node->in2()->isScalar() ? cast : "") << gen(node->in2()) << ", " - << (node->in3()->isScalar() ? cast : "") << gen(node->in3()) << ")"; + if (top->getTernaryOpType() == TernaryOpType::Where) { + auto cast = scalarCast(top->in2(), top->in3()); + code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << ", " + << (top->in3()->isScalar() ? 
cast : "") << gen(top->in3()) << ")"; } else { - code_ << gen(node->in2()) << ", " << gen(node->in3()) << ")"; + code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")"; } if (!print_inline_) { @@ -670,7 +667,7 @@ class CudaKernelGenerator : private kir::IrVisitor { } } - std::string genReductionOp(BinaryOpType op_type, kir::Val* out) { + std::string genReductionOp(BinaryOpType op_type, Val* out) { std::stringstream lambda; DataType data_type = out->dtype(); lambda << "[](" << data_type << " &a, " << data_type << " b) " @@ -678,47 +675,45 @@ class CudaKernelGenerator : private kir::IrVisitor { return lambda.str(); } - void visit(const kir::BroadcastOp* node) final { - TORCH_INTERNAL_ASSERT(node->out()->isA()); - const auto tensor_index = node->out()->as(); + void handle(const BroadcastOp* stmt) final { + TORCH_INTERNAL_ASSERT(stmt->out()->isA()); + const auto tensor_index = stmt->out()->as(); - const ParallelTypeBitmap domains = - kernel_->predicateMap().getParallelBroadcastDomains( - tensor_index->view()->fuserTv()); + const ParallelTypeBitmap parallel_types = + kernel_->summary().broadcast_parallel_types.at(stmt); - const bool thread_x = domains.get(ParallelType::TIDx); - const bool thread_y = domains.get(ParallelType::TIDy); - const bool thread_z = domains.get(ParallelType::TIDz); - const bool block_x = domains.get(ParallelType::BIDx); - const bool block_y = domains.get(ParallelType::BIDy); - const bool block_z = domains.get(ParallelType::BIDz); - - const bool grid_broadcast_needed = block_x || block_y || block_z; - const bool block_broadcast_needed = thread_x || thread_y || thread_z; + if (parallel_types.none()) { + // Not parallelized + indent() << gen(stmt->out()) << "\n"; + indent() << kTab << " = " << gen(stmt->in()) << ";\n"; + return; + } TORCH_INTERNAL_ASSERT( - !grid_broadcast_needed, - "Parallel broadcast across blocks not supported"); - - if (block_broadcast_needed) { - const auto data_type = node->out()->dtype(); - indent() << "broadcast::blockBroadcast<" << (thread_x ? "true" : "false") - << ", " << (thread_y ? "true" : "false") << ", " - << (thread_z ? "true" : "false") << ">(\n"; - indent() << kTab << gen(node->out()) << ",\n"; - indent() << kTab << gen(node->in()) << ",\n"; - indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; - TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - indent() << kTab << genInline(node->predicate()) << ");\n"; - } else { - indent() << gen(node->out()) << "\n"; - indent() << kTab << " = " << gen(node->in()) << ";\n"; + !parallel_types.hasBID(), + "Parallel broadcast across blocks should have been translated to a GridBroadcast IR node"); + + std::stringstream flags_str; + for (const ParallelType pt : kParallelTypeTIDs) { + const bool parallel_bcast = parallel_types.get(pt); + if (pt != kParallelTypeTIDs[0]) { + flags_str << ", "; + } + flags_str << (parallel_bcast ? 
"true" : "false"); } + + const auto data_type = stmt->out()->dtype(); + indent() << "broadcast::blockBroadcast<" << flags_str.str() << ">(\n"; + indent() << kTab << gen(stmt->out()) << ",\n"; + indent() << kTab << gen(stmt->in()) << ",\n"; + indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; + TORCH_INTERNAL_ASSERT( + stmt->predicate() != nullptr && stmt->predicate()->hasValue()); + indent() << kTab << genInline(stmt->predicate()) << ");\n"; } void genWarpReductionOp( - const kir::ReductionOp* node, + const ReductionOp* rop, const IterDomain* reduction_id) { bool is_single_warp = kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp; @@ -729,24 +724,25 @@ class CudaKernelGenerator : private kir::IrVisitor { } else { code_ << "(\n"; } - indent() << kTab << gen(node->out()) << ",\n"; - indent() << kTab << gen(node->in()) << ",\n"; - indent() << kTab << genReductionOp(node->operation(), node->out()) << ",\n"; + indent() << kTab << gen(rop->out()) << ",\n"; + indent() << kTab << gen(rop->in()) << ",\n"; + indent() << kTab << genReductionOp(rop->getReductionOpType(), rop->out()) + << ",\n"; indent() << kTab << "threadIdx,\n"; indent() << kTab << "blockDim,\n"; - indent() << kTab << "static_cast<" << node->out()->dtype() + indent() << kTab << "static_cast<" << rop->out()->dtype() << "*>(shared_mem),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - indent() << kTab << genInline(node->predicate()) << ",\n"; - indent() << kTab << node->out()->dtype() << "(" << genInline(node->init()) + rop->predicate() != nullptr && rop->predicate()->hasValue()); + indent() << kTab << genInline(rop->predicate()) << ",\n"; + indent() << kTab << rop->out()->dtype() << "(" << genInline(rop->init()) << "));\n"; } - void visit(const kir::ReductionOp* node) final { - TORCH_INTERNAL_ASSERT(node->out()->isA()); + void handle(const ReductionOp* rop) final { + TORCH_INTERNAL_ASSERT(rop->out()->isA()); - const auto out = node->out()->as(); + const auto out = rop->out()->as(); const auto domain = out->view()->domain(); const bool has_block_reduce = domain->hasBlockReduction(); @@ -754,18 +750,18 @@ class CudaKernelGenerator : private kir::IrVisitor { if (!has_block_reduce && !has_grid_reduce) { const auto gen_out = gen(out); - const auto op_type = node->operation(); + const auto op_type = rop->getReductionOpType(); indent() << gen_out << " = " - << genBinaryOp(op_type, out, gen_out, gen(node->in())) << ";\n"; + << genBinaryOp(op_type, out, gen_out, gen(rop->in())) << ";\n"; return; } - if (auto reduction_id = ir_utils::getMaybeWarpReductionDim(node)) { - genWarpReductionOp(node, reduction_id.value()); + if (auto reduction_id = ir_utils::getMaybeWarpReductionDim(rop)) { + genWarpReductionOp(rop, reduction_id.value()); return; } - const auto par_domains = ir_utils::getParallelDomains(node->out()); + const auto par_domains = ir_utils::getParallelDomains(rop->out()); // Get parallel reduction domains const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end() && @@ -777,14 +773,14 @@ class CudaKernelGenerator : private kir::IrVisitor { par_domains.find(ParallelType::TIDz) != par_domains.end() && par_domains.at(ParallelType::TIDz)->isReduction(); - const auto data_type = node->out()->dtype(); - const auto op_type = node->operation(); + const auto data_type = rop->out()->dtype(); + const auto op_type = rop->getReductionOpType(); if (has_block_reduce) { if (has_grid_reduce) { indent() << data_type << " " << "block_result_" << block_reduce_name_ << 
"=" - << gen(node->init()) << ";\n"; + << gen(rop->init()) << ";\n"; } indent() << "blockReduce<" << (tidx ? "true" : "false") << ", " << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false") @@ -792,44 +788,43 @@ class CudaKernelGenerator : private kir::IrVisitor { if (has_grid_reduce) { indent() << kTab << "block_result_" << block_reduce_name_ << ",\n"; } else { - indent() << kTab << gen(node->out()) << ",\n"; + indent() << kTab << gen(rop->out()) << ",\n"; } - indent() << kTab << gen(node->in()) << ",\n"; - indent() << kTab << genReductionOp(op_type, node->out()) << ",\n"; + indent() << kTab << gen(rop->in()) << ",\n"; + indent() << kTab << genReductionOp(op_type, rop->out()) << ",\n"; indent() << kTab << "threadIdx,\n"; indent() << kTab << "blockDim,\n"; indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + rop->predicate() != nullptr && rop->predicate()->hasValue()); + auto read_pred = genInline(rop->predicate()); indent() << kTab << read_pred << ",\n"; // Pass the write predicate if available and different from the // default predicate. The blockReduce runtime function uses the // default predicate for both read and write when only the // default one is given. - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (rop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(rop->writePredicate()->hasValue()); + auto write_pred = genInline(rop->writePredicate()); indent() << kTab << write_pred << ",\n"; } - indent() << kTab << data_type << "(" << genInline(node->init()) - << "));\n"; + indent() << kTab << data_type << "(" << genInline(rop->init()) << "));\n"; } } - void visit(const kir::WelfordOp* node) final { - TORCH_INTERNAL_ASSERT(node->out()->isA()); + void handle(const WelfordOp* wop) final { + TORCH_INTERNAL_ASSERT(wop->out()->isA()); - const auto out = node->out()->as(); + const auto out = wop->out()->as(); const auto domain = out->view()->domain(); - const auto out_var = node->outVar(); - const auto out_avg = node->outAvg(); - const auto out_N = node->outN(); + const auto out_var = wop->outVar(); + const auto out_avg = wop->outAvg(); + const auto out_N = wop->outN(); - const auto in_var = node->inVar(); - const auto in_avg = node->inAvg(); - const auto in_N = node->inN(); + const auto in_var = wop->inVar(); + const auto in_avg = wop->inAvg(); + const auto in_N = wop->inN(); const bool has_block_reduce = domain->hasBlockReduction(); const bool has_grid_reduce = domain->hasGridReduction(); @@ -852,7 +847,7 @@ class CudaKernelGenerator : private kir::IrVisitor { return; } - const auto par_domains = ir_utils::getParallelDomains(node->out()); + const auto par_domains = ir_utils::getParallelDomains(wop->out()); // Get parallel reduction domains const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end() && @@ -864,20 +859,20 @@ class CudaKernelGenerator : private kir::IrVisitor { par_domains.find(ParallelType::TIDz) != par_domains.end() && par_domains.at(ParallelType::TIDz)->isReduction(); - const auto data_type = node->out()->dtype(); + const auto data_type = wop->out()->dtype(); if (has_block_reduce) { if (has_grid_reduce) { // allocate block result indent() << data_type << " " << "block_result_avg_" << block_reduce_name_ << " = " - << gen(node->initAvg()) << ";\n"; + << gen(wop->initAvg()) << 
";\n"; indent() << data_type << " " << "block_result_var_" << block_reduce_name_ << " = " - << gen(node->initVar()) << ";\n"; + << gen(wop->initVar()) << ";\n"; indent() << DataType::Int << " " << "block_result_n_" << block_reduce_name_ << " = " - << gen(node->initN()) << ";\n"; + << gen(wop->initN()) << ";\n"; } indent() << "blockWelford<" << (tidx ? "true" : "false") << ", " << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false") @@ -887,9 +882,9 @@ class CudaKernelGenerator : private kir::IrVisitor { << kTab << "block_result_var_" << block_reduce_name_ << ",\n" << kTab << "block_result_n_" << block_reduce_name_ << ",\n"; } else { - indent() << kTab << gen(node->outAvg()) << ",\n"; - indent() << kTab << gen(node->outVar()) << ",\n"; - indent() << kTab << gen(node->outN()) << ",\n"; + indent() << kTab << gen(wop->outAvg()) << ",\n"; + indent() << kTab << gen(wop->outVar()) << ",\n"; + indent() << kTab << gen(wop->outN()) << ",\n"; } indent() << " " << gen(in_avg) << ",\n"; if (in_var) { @@ -907,14 +902,14 @@ class CudaKernelGenerator : private kir::IrVisitor { << "*>(shared_mem_var),\n"; indent() << kTab << "reinterpret_cast<" << DataType::Int << "*>(shared_mem_n),\n"; - TORCH_INTERNAL_ASSERT(node->predicate() != nullptr); + TORCH_INTERNAL_ASSERT(wop->predicate() != nullptr); TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + wop->predicate() != nullptr && wop->predicate()->hasValue()); + auto read_pred = genInline(wop->predicate()); indent() << kTab << read_pred << ",\n"; - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (wop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(wop->writePredicate()->hasValue()); + auto write_pred = genInline(wop->writePredicate()); indent() << kTab << write_pred << ",\n"; } indent() << kTab << data_type << "(0));\n"; @@ -954,8 +949,8 @@ class CudaKernelGenerator : private kir::IrVisitor { return flags.str(); } - void visit(const kir::GridReduction* node) final { - const auto rop = node->reduction_op(); + void handle(const kir::GridReduction* grop) final { + const auto rop = grop->reduction_op(); TORCH_INTERNAL_ASSERT(rop->out()->isA()); const auto out = rop->out()->as(); @@ -963,19 +958,17 @@ class CudaKernelGenerator : private kir::IrVisitor { TORCH_INTERNAL_ASSERT(domain->hasGridReduction()); const auto data_type = rop->out()->dtype(); - const auto op_type = rop->operation(); + const auto op_type = rop->getReductionOpType(); TORCH_INTERNAL_ASSERT( - node->reduction_buffer()->buffer()->isA()); - TORCH_INTERNAL_ASSERT( - node->sync_buffer()->buffer()->isA()); + grop->reduction_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA()); const auto work_buffer = - node->reduction_buffer()->buffer()->as(); - const auto sync_buffer = - node->sync_buffer()->buffer()->as(); + grop->reduction_buffer()->buffer()->as(); + const auto sync_buffer = grop->sync_buffer()->buffer()->as(); const std::string flags_str = - generateGridReduceTemplateFlags(rop, node->threadPredicate()); + generateGridReduceTemplateFlags(rop, grop->threadPredicate()); const bool persistent_sync = kernel_->summary().has_cooperative_grid_reduction; @@ -996,44 +989,46 @@ class CudaKernelGenerator : private kir::IrVisitor { indent() << kTab << varName(sync_buffer) << ",\n"; indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; 
TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + grop->predicate() != nullptr && grop->predicate()->hasValue()); + auto read_pred = genInline(grop->predicate()); indent() << kTab << read_pred << ",\n"; - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (grop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue()); + auto write_pred = genInline(grop->writePredicate()); indent() << kTab << write_pred << ",\n"; } else { indent() << kTab << read_pred << ",\n"; } indent() << kTab << data_type << "(" - << genInline(node->reduction_op()->init()) << "));\n"; + << genInline(grop->reduction_op()->init()) << "));\n"; } - void visit(const kir::GridBroadcast* node) final { - const auto bop = node->broadcast_op(); + void handle(const kir::GridBroadcast* grop) final { + const auto bop = grop->broadcast_op(); TORCH_INTERNAL_ASSERT(bop->out()->isA()); + const ParallelTypeBitmap parallel_types = + kernel_->summary().broadcast_parallel_types.at(bop); + + TORCH_INTERNAL_ASSERT( + parallel_types.hasBID(), + "GridBroadcast needs to be used with a broadcast op that is parallelized with the BID parallel types"); + const auto out = bop->out()->as(); const auto domain = out->view()->domain(); - TORCH_INTERNAL_ASSERT(domain->hasGridBroadcast()); const auto data_type = bop->out()->dtype(); TORCH_INTERNAL_ASSERT( - node->broadcast_buffer()->buffer()->isA()); - TORCH_INTERNAL_ASSERT( - node->sync_buffer()->buffer()->isA()); + grop->broadcast_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA()); const auto work_buffer = - node->broadcast_buffer()->buffer()->as(); - const auto sync_buffer = - node->sync_buffer()->buffer()->as(); + grop->broadcast_buffer()->buffer()->as(); + const auto sync_buffer = grop->sync_buffer()->buffer()->as(); - const auto par_domains = ir_utils::getParallelDomains(out); std::stringstream flags_str; for (const ParallelType pt : kParallelTypeThreads) { - const bool parallel_bcast = par_domains.find(pt) != par_domains.end() && - par_domains.at(pt)->isBroadcast(); + const bool parallel_bcast = parallel_types.get(pt); if (pt != kParallelTypeThreads[0]) { flags_str << ", "; } @@ -1041,7 +1036,7 @@ class CudaKernelGenerator : private kir::IrVisitor { } // Since block-level broadcast has not necessarily been performed before - // this function call, so grid broadcast may be broadcasting across both + // this function call, so grid broadcast may be broadcasting across both // the grid and the block level. 
indent() << "grid_broadcast::broadcast<" << flags_str.str() << ">(\n"; indent() << kTab << gen(bop->out()) << ",\n"; @@ -1049,12 +1044,12 @@ class CudaKernelGenerator : private kir::IrVisitor { indent() << kTab << "&" << varName(work_buffer) << "[0],\n"; indent() << kTab << varName(sync_buffer) << ",\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - indent() << kTab << genInline(node->predicate()) << ");\n"; + grop->predicate() != nullptr && grop->predicate()->hasValue()); + indent() << kTab << genInline(grop->predicate()) << ");\n"; } - void visit(const kir::GridWelford* node) final { - const auto wop = node->welford_op(); + void handle(const kir::GridWelford* gwop) final { + const auto wop = gwop->welford_op(); TORCH_INTERNAL_ASSERT(wop->outAvg()->isA()); const auto out = wop->out()->as(); @@ -1063,21 +1058,19 @@ class CudaKernelGenerator : private kir::IrVisitor { const auto data_type = out->dtype(); - TORCH_INTERNAL_ASSERT(node->var_buffer()->buffer()->isA()); - TORCH_INTERNAL_ASSERT( - node->sync_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(gwop->var_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(gwop->sync_buffer()->buffer()->isA()); - const auto avg_buffer = node->avg_buffer()->buffer()->as(); - const auto var_buffer = node->var_buffer()->buffer()->as(); - const auto n_buffer = node->N_buffer()->buffer()->as(); - const auto sync_buffer = - node->sync_buffer()->buffer()->as(); + const auto avg_buffer = gwop->avg_buffer()->buffer()->as(); + const auto var_buffer = gwop->var_buffer()->buffer()->as(); + const auto n_buffer = gwop->N_buffer()->buffer()->as(); + const auto sync_buffer = gwop->sync_buffer()->buffer()->as(); const bool persistent_sync = kernel_->summary().has_cooperative_grid_reduction; const std::string flags_str = - generateGridReduceTemplateFlags(wop, node->threadPredicate()); + generateGridReduceTemplateFlags(wop, gwop->threadPredicate()); // Since block-level reduction is already done, those dimensions // with tidx/y/z being true do not participate in the grid reduction. 
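// ---- editor's sketch (illustrative only, not part of this patch) ----
// The broadcast and grid handlers above emit a comma-separated list of
// boolean template flags derived from the parallel types recorded for the
// op. A minimal standalone sketch of that flag-assembly step; ParallelType,
// the flag order, and joinFlags below are stand-ins, not the nvfuser API.
#include <array>
#include <iostream>
#include <set>
#include <sstream>
#include <string>

namespace sketch {

enum class ParallelType { BIDx, BIDy, BIDz, TIDx, TIDy, TIDz };

// Fixed iteration order so the emitted flags line up with the runtime
// helper's template parameters (analogous to kParallelTypeThreads).
constexpr std::array<ParallelType, 6> kAllTypes = {
    ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz,
    ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz};

// Build "true, false, ..." from the set of active parallel types.
std::string joinFlags(const std::set<ParallelType>& active) {
  std::stringstream flags;
  for (std::size_t i = 0; i < kAllTypes.size(); ++i) {
    if (i != 0) {
      flags << ", ";
    }
    flags << (active.count(kAllTypes[i]) ? "true" : "false");
  }
  return flags.str();
}

} // namespace sketch

int main() {
  // A broadcast parallelized over BIDy and TIDx would be printed as
  // "false, true, false, true, false, false".
  std::cout << sketch::joinFlags(
                   {sketch::ParallelType::BIDy, sketch::ParallelType::TIDx})
            << "\n";
  return 0;
}
// ---- end of editor's sketch ----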
@@ -1112,12 +1105,12 @@ class CudaKernelGenerator : private kir::IrVisitor { indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype() << "*>(shared_mem_n),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + gwop->predicate() != nullptr && gwop->predicate()->hasValue()); + auto read_pred = genInline(gwop->predicate()); indent() << kTab << read_pred << ",\n"; - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (gwop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue()); + auto write_pred = genInline(gwop->writePredicate()); indent() << kTab << write_pred << ",\n"; } else { indent() << kTab << read_pred << ",\n"; @@ -1128,27 +1121,26 @@ class CudaKernelGenerator : private kir::IrVisitor { void handleScope(const kir::Scope& scope) { for (auto expr : scope.exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } - void visit(const kir::ForLoop* node) final { - // TODO(kir): handle this during lowering - if (node->iter_domain()->isBroadcast()) { - handleScope(node->body()); + void handle(const kir::ForLoop* loop) final { + if (loop->iter_domain()->isBroadcast()) { + handleScope(loop->body()); return; - } else if (node->vectorize()) { - vectorize_scope_ = node->vectorize(); - handleScope(node->body()); + } else if (loop->vectorize()) { + vectorize_scope_ = loop->vectorize(); + handleScope(loop->body()); vectorize_scope_ = false; return; - } else if (node->iter_domain()->isStride()) { + } else if (loop->iter_domain()->isStride()) { // A stride domain only executes the loop body with the loop // index being zero. indent() << "constexpr " << "nvfuser_index_t" - << " " << gen(node->index()) << " = 0;\n"; - handleScope(node->body()); + << " " << gen(loop->index()) << " = 0;\n"; + handleScope(loop->body()); return; } @@ -1168,56 +1160,82 @@ class CudaKernelGenerator : private kir::IrVisitor { // necessary since the loop stop value just needs to be <= the // IterDomain extent. However, at this point, this conservative // analysis seems sufficient. - if (node->stop() == node->iter_domain()->extent() && - node->iter_domain()->isThread()) { + if (loop->stop() == loop->iter_domain()->extent() && + loop->iter_domain()->isThread()) { // Register a replacement of references to the loop index with // the loop start value. - replacement_map_.insert({node->index(), node->start()}); - handleScope(node->body()); - replacement_map_.erase(node->index()); + replacement_map_.insert({loop->index(), loop->start()}); + handleScope(loop->body()); + replacement_map_.erase(loop->index()); return; } - if (node->start()->isZeroInt() && node->stop()->isOneInt()) { + if (loop->start()->isZeroInt() && loop->stop()->isOneInt()) { indent() << "constexpr " << "nvfuser_index_t" - << " " << gen(node->index()) << " = 0;\n"; - handleScope(node->body()); + << " " << gen(loop->index()) << " = 0;\n"; + handleScope(loop->body()); + return; + } else if ( + // Special case handling for a pattern where start == end - 1. 
+ loop->start()->definition() != nullptr && + loop->start()->definition()->isA() && + loop->start()->definition()->as()->getBinaryOpType() == + BinaryOpType::Sub && + loop->start()->definition()->as()->lhs() == loop->stop() && + loop->start()->definition()->as()->rhs()->isOneInt()) { + indent() << "const " + << "nvfuser_index_t" + << " " << gen(loop->index()) << " = " << genInline(loop->start()) + << ";\n"; + handleScope(loop->body()); return; } - const auto gen_index = gen(node->index()); - const auto gen_start = genInline(node->start()); - const auto gen_stop = genInline(node->stop()); - const auto gen_step = genInline(node->step()); + const auto gen_index = gen(loop->index()); + const auto gen_start = genInline(loop->start()); + const auto gen_stop = genInline(loop->stop()); + const auto gen_step = genInline(loop->step()); std::stringstream step_code; - if (node->step()->isOneInt()) { + if (loop->step()->isOneInt()) { step_code << "++" << gen_index; } else { step_code << gen_index << " += " << gen_step; } - if (node->isUnrolled()) { + if (loop->isUnrolled()) { indent() << "#pragma unroll\n"; } else { indent() << "#pragma unroll 1\n"; } - indent() << "for(nvfuser_index_t " << gen_index << " = " << gen_start - << "; " << gen_index << " < " << gen_stop << "; " - << step_code.str() << ") "; + + indent() << "for(nvfuser_index_t " << gen_index; + if (loop->iter_domain()->isParallelized()) { + code_ << " = " << gen_start << "; "; + } else { + // Do not start at the start of the ID when not parallelized. Instead, + // start at 0. Predicates will protect buffers between 0 and ID->start(), + // however if we started at ID->start and extent == ID->start, we could + // have a "degenerate" loop (loop with no iterations). It may not be an + // issue to have a 0-sized loop, but all potential consequences haven't + // been covered. One example is WAR analysis which could incorrectly think + // a barrier inside a 0-sized loop actually provides protection. 
+ code_ << " = 0; "; + } + code_ << gen_index << " < " << gen_stop << "; " << step_code.str() << ") "; startBlock(true); - handleScope(node->body()); + handleScope(loop->body()); endBlock(); } - void visit(const kir::IfThenElse* node) final { - auto conditional = node->predicate()->value(); + void handle(const kir::IfThenElse* ite) final { + auto conditional = ite->predicate()->value(); if (conditional->isConst()) { // If the conditional is a constant, then the IfThenElse is not required if (conditional->value().value()) { - handleScope(node->thenBody()); + handleScope(ite->thenBody()); } else { - handleScope(node->elseBody()); + handleScope(ite->elseBody()); } return; } @@ -1226,41 +1244,40 @@ class CudaKernelGenerator : private kir::IrVisitor { // "then" block startBlock(true); - handleScope(node->thenBody()); + handleScope(ite->thenBody()); // "else" block (optional) - if (node->hasElse()) { + if (ite->hasElse()) { endBlock(" else "); startBlock(true); - handleScope(node->elseBody()); + handleScope(ite->elseBody()); } endBlock(); } - // TODO(kir): fold initialization into Allocate - void visit(const kir::Allocate* node) final { - const auto buffer_dtype = node->buffer()->dtype(); + void handle(const kir::Allocate* alloc) final { + const auto buffer_dtype = alloc->buffer()->dtype(); - if (!node->buffer()->isA()) { - indent() << buffer_dtype << " " << gen(node->buffer()) << ";\n"; + if (!alloc->buffer()->isA()) { + indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n"; return; } - const auto tv = node->buffer()->as(); + const auto tv = alloc->buffer()->as(); - const auto size = node->size(); + const auto size = alloc->size(); TORCH_INTERNAL_ASSERT(size != nullptr); - if (node->alias() != nullptr) { - // Allocate alias another Allocate node - const auto alias_tv = node->alias()->buffer()->as(); - indent() << "// Alias Allocation - " << node->memoryType() << "\n"; + if (alloc->alias() != nullptr) { + // Allocate alias another Allocate stmt + const auto alias_tv = alloc->alias()->buffer()->as(); + indent() << "// Alias Allocation - " << alloc->memoryType() << "\n"; indent() << buffer_dtype << "* " << varName(tv) << " = " << varName(alias_tv) << ";\n"; } else { // Standard Memory Allocation - switch (tv->memoryType()) { + switch (tv->getMemoryType()) { case MemoryType::Global: indent() << "// Allocate global tensor " << varName(tv) << "\n"; break; @@ -1292,7 +1309,7 @@ class CudaKernelGenerator : private kir::IrVisitor { } } - void visit(const kir::Sync* node) final { + void handle(const kir::Sync*) final { // Use a custom synchronization method if enabled if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { indent() << "block_sync::sync();\n"; @@ -1301,11 +1318,11 @@ class CudaKernelGenerator : private kir::IrVisitor { } } - void visit(const kir::InitMagicZero* node) final { + void handle(const kir::InitMagicZero*) final { indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; } - void visit(const kir::UpdateMagicZero* node) final { + void handle(const kir::UpdateMagicZero*) final { indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; } @@ -1314,15 +1331,13 @@ class CudaKernelGenerator : private kir::IrVisitor { const kir::Kernel* kernel_; int block_nest_level_ = 0; int block_reduce_name_ = 0; - - // TODO(kir): replace with explicit assignment statements bool print_inline_ = false; // Mark when we are inside of a vectorized for-loop bool vectorize_scope_ = false; //! 
Holds active replacement mappings during codegen - std::unordered_map replacement_map_; + std::unordered_map replacement_map_; }; } // namespace diff --git a/torch/csrc/jit/codegen/cuda/codegen.h b/torch/csrc/jit/codegen/cuda/codegen.h index 2ffbb872155..31e4fb70736 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.h +++ b/torch/csrc/jit/codegen/cuda/codegen.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 45f744d7e2f..f51e0fe1bc9 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -59,14 +59,8 @@ bool validateDomain(TensorView* tv, TensorDomain* new_td) { unsigned int getReplayablePosPasC( TensorView* producer, TensorView* consumer, - const ComputeAtRootDomainMap& root_map_, + const std::unordered_set& unmappable_producer_dims, ComputeAtMode mode) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline reduction structures. - auto mappable_roots = - root_map_.getMappableDims(producer->domain(), consumer->domain()); - // Check if any consumer dimensions are marked as vectorize as producer can // not be inlined to vectorized dimensions in consumer. auto c_dom = consumer->domain()->domain(); @@ -124,9 +118,14 @@ unsigned int getReplayablePosPasC( if (std::any_of( consumer_root_dim_ids.begin(), consumer_root_dim_ids.end(), - [&mappable_roots, &c2p_root_map](IterDomain* root_id) { - return mappable_roots.find(root_id) == mappable_roots.end() && - c2p_root_map.find(root_id) != c2p_root_map.end(); + [&unmappable_producer_dims, &c2p_root_map](IterDomain* c_root_id) { + auto p_root_id_it = c2p_root_map.find(c_root_id); + if (p_root_id_it == c2p_root_map.end()) { + return false; + } + auto p_id = p_root_id_it->second; + return unmappable_producer_dims.find(p_id) != + unmappable_producer_dims.end(); })) { continue; } @@ -146,14 +145,8 @@ unsigned int getReplayablePosPasC( unsigned int getReplayablePosCasP( TensorView* consumer, TensorView* producer, - const ComputeAtRootDomainMap& root_map_, + const std::unordered_set& unmappable_producer_dims, ComputeAtMode mode) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline reduction structures. 
- auto mappable_roots = - root_map_.getMappableDims(producer->domain(), consumer->domain()); - auto p_dom = producer->domain()->domain(); auto first_reduction = std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { @@ -208,10 +201,11 @@ unsigned int getReplayablePosCasP( if (std::any_of( producer->getMaybeRFactorDomain().begin(), producer->getMaybeRFactorDomain().end(), - [&mappable_roots, &all_vals](IterDomain* root_id) { - return std::find(all_vals.begin(), all_vals.end(), root_id) != + [&unmappable_producer_dims, &all_vals](IterDomain* p_root_id) { + return std::find(all_vals.begin(), all_vals.end(), p_root_id) != all_vals.end() && - mappable_roots.find(root_id) == mappable_roots.end(); + unmappable_producer_dims.find(p_root_id) != + unmappable_producer_dims.end(); })) { continue; } @@ -446,7 +440,8 @@ unsigned int ComputeAt::backwardComputeAt_impl( FUSER_PERF_SCOPE("backwardComputeAt_impl"); auto max_consumer_compute_at_pos = - getReplayablePosPasC(producer, consumer, root_map_, mode_); + getReplayablePosPasC(producer, consumer, unmappable_dims_, mode_); + if (mode_ == ComputeAtMode::BestEffort) { consumer_compute_at_pos = std::min(consumer_compute_at_pos, max_consumer_compute_at_pos); @@ -517,7 +512,7 @@ unsigned int ComputeAt::forwardComputeAt_impl( FUSER_PERF_SCOPE("forwardComputeAt_impl"); auto max_producer_compute_at_pos = - getReplayablePosCasP(consumer, producer, root_map_, mode_); + getReplayablePosCasP(consumer, producer, unmappable_dims_, mode_); if (mode_ == ComputeAtMode::BestEffort) { producer_compute_at_pos = @@ -865,6 +860,25 @@ void ComputeAt::runPass() { } } +void ComputeAt::buildUnmappableDims() { + auto all_tvs = ir_utils::allTvs(producer_->fusion()); + for (auto tv : all_tvs) { + auto consumers = ir_utils::consumerTvsOf(tv); + for (auto consumer : consumers) { + // Grab dimensions in producer and consumer that are mappable to each other + // based on the computeAtRootDomainMap. This will tell us which dimensions + // can be inlined based on avoiding trying to inline reduction structures. + auto mappable_roots = + root_map_.getMappableDims(tv->domain(), consumer->domain()); + for (auto tv_root_id : tv->getMaybeRFactorDomain()) { + if (mappable_roots.find(tv_root_id) == mappable_roots.end()) { + unmappable_dims_.emplace(tv_root_id); + } + } + } + } +} + ComputeAt::ComputeAt( TensorView* _producer, TensorView* _consumer, @@ -903,6 +917,8 @@ ComputeAt::ComputeAt( setCommonConsumer(); root_map_.build(); + + buildUnmappableDims(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 391225218db..75fca5705ed 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -2,11 +2,12 @@ #include +#include #include -#include #include #include +#include #include namespace torch { @@ -68,6 +69,10 @@ class ComputeAt { // call. void setCommonConsumer(); + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + // Propagate backward from consumer to producer, check if it increases // computeAt position on tensors, if so take it! void traverseBackward(); @@ -106,6 +111,9 @@ class ComputeAt { // Producer use chains set in, used in a few spots.
std::deque> producer_use_chains_; + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims_; + ComputeAt( TensorView* _producer, TensorView* _consumer, diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 6671fc37546..f46a7495163 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include @@ -488,71 +487,6 @@ void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) { } } } - - if (gpu_lower != nullptr) { - convertToKir(fusion, gpu_lower); - } -} - -void ComputeAtMap::convertToKir(Fusion* fusion, GpuLower* gpu_lower) { - TORCH_INTERNAL_ASSERT(fusion != nullptr); - TORCH_INTERNAL_ASSERT(gpu_lower != nullptr); - - has_lowered_kir_ = true; - - std::unordered_map< - std::shared_ptr>, - std::shared_ptr>> - disjoint_set_2_kir; - - for (const auto& disjoint_iter_set : disjoint_iter_set_maps_) { - auto fusion_set = disjoint_iter_set.second; - auto kir_set_it = disjoint_set_2_kir.find(fusion_set); - std::shared_ptr> kir_set; - if (kir_set_it == disjoint_set_2_kir.end()) { - kir_set = std::make_shared>(); - std::transform( - fusion_set->begin(), - fusion_set->end(), - std::inserter(*kir_set, kir_set->begin()), - [&gpu_lower](IterDomain* id) { - return gpu_lower->lowerValue(id)->as(); - }); - disjoint_set_2_kir.emplace(std::make_pair(fusion_set, kir_set)); - } else { - kir_set = kir_set_it->second; - } - kir_disjoint_iter_set_maps_.emplace(std::make_pair( - gpu_lower->lowerValue(disjoint_iter_set.first)->as(), - kir_set)); - } - - for (auto entry : concrete_id_map_) { - kir_concrete_id_map_.emplace(std::make_pair( - gpu_lower->lowerValue(entry.first)->as(), - gpu_lower->lowerValue(entry.second)->as())); - } - - for (const auto& entry : disjoint_iter_set_maps_) { - kir_2_fusion_[gpu_lower->lowerValue(entry.first)->as()] = - entry.first; - } - - // Make sure we have all IterDomains that could be used to generate a ForLoop - for (auto expr : fusion->exprs()) { - if (!expr->outputs()[0]->isA()) { - continue; - } - - auto tv_outputs = ir_utils::filterByType(expr->outputs()); - - for (auto out : tv_outputs) { - for (auto entry : out->domain()->domain()) { - kir_2_fusion_[gpu_lower->lowerValue(entry)->as()] = - entry; - } - } - } } bool ComputeAtMap::areMapped(IterDomain* id0, IterDomain* id1) const { @@ -568,20 +502,6 @@ bool ComputeAtMap::areMapped(IterDomain* id0, IterDomain* id1) const { return (set0_it->second.get() == set1_it->second.get()); } -bool ComputeAtMap::areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const { - assertLowered(has_lowered_kir_); - if (id0 == id1) { - return true; - } - auto set0_it = kir_disjoint_iter_set_maps_.find(id0); - auto set1_it = kir_disjoint_iter_set_maps_.find(id1); - if (set0_it == kir_disjoint_iter_set_maps_.end() || - set1_it == kir_disjoint_iter_set_maps_.end()) { - return false; - } - return (set0_it->second.get() == set1_it->second.get()); -} - IterDomain* ComputeAtMap::getConcreteMappedID(IterDomain* id) const { auto it = concrete_id_map_.find(id); if (it != concrete_id_map_.end()) { @@ -590,25 +510,6 @@ IterDomain* ComputeAtMap::getConcreteMappedID(IterDomain* id) const { return id; } -kir::IterDomain* ComputeAtMap::getConcreteMappedID(kir::IterDomain* id) const { - assertLowered(has_lowered_kir_); - auto it = kir_concrete_id_map_.find(id); - if (it != kir_concrete_id_map_.end()) { - return it->second; - } - return 
id; -} - -IterDomain* ComputeAtMap::toFusion(kir::IterDomain* kir) const { - assertLowered(has_lowered_kir_); - auto kir_2_fusion_it = kir_2_fusion_.find(kir); - TORCH_INTERNAL_ASSERT( - kir_2_fusion_it != kir_2_fusion_.end(), - "Kernel ir is not guarneteed to be reversible into fusion ir, could not find fusion entry. ", - kir::toString(kir, false)); - return kir_2_fusion_it->second; -} - std::string ComputeAtMap::toString() const { std::stringstream ss; diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index b2b70f8997d..8b7f9acd8fe 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -67,34 +67,18 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! same loop nest in the lowered code bool areMapped(IterDomain* id0, IterDomain* id1) const; - bool areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const; - //! Returns an iter domain that is the maximum expanded size of all iter //! domains the one provided maps to. Useful for opening loops to the correct //! iteration size. Not guarenteed to return the same ID every call, but is //! guarenteed to return iter domains in the same disjoint set. IterDomain* getConcreteMappedID(IterDomain* id) const; - kir::IterDomain* getConcreteMappedID(kir::IterDomain* id) const; - - // TODO: Would be great if we didn't need this, but we have nice functionality - // in iter_visitor that isn't moved over. Use of this is limited to indexing - // and this should definitely be removed by building out kernel ir to have - // better parity with fusion ir. - IterDomain* toFusion(kir::IterDomain* kir) const; - // Prints mapping information via Fusion IR std::string toString() const; private: - bool has_lowered_kir_ = false; - void mapIds(IterDomain* id0, IterDomain* id1); - //! Convert everything to lowered structures (kernel ir), as we will use - //! this class frequently during lowering. - void convertToKir(Fusion* fusion, GpuLower* gpu_lower); - private: MappingMode mapping_mode_ = MappingMode::LOOP; @@ -109,11 +93,6 @@ class TORCH_CUDA_CU_API ComputeAtMap { std::unordered_map>> disjoint_iter_set_maps_; - std::unordered_map< - kir::IterDomain*, - std::shared_ptr>> - kir_disjoint_iter_set_maps_; - // Keep a list of disjoint_iter_sets that's deterministic to iterate over std::deque>> disjoint_iter_sets_; @@ -125,12 +104,6 @@ class TORCH_CUDA_CU_API ComputeAtMap { // For each IterDomain set we will track how many concrete root domains were // used to generate the IterDomain std::unordered_map concrete_id_map_; - - std::unordered_map kir_concrete_id_map_; - - // Map kir::IterDomain* back to the fusion IR IterDomain*. - // TODO: Would be great if we didn't need this. 
- std::unordered_map kir_2_fusion_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index cea8b24e7ff..1702de93bdd 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -37,7 +37,7 @@ T* ptr(T* obj) { * } * * And therefore dispatch should never call: - * ptr(mutator)->handle(this->as()); + * ptr(mutator)->mutate(this->as()); */ template @@ -58,6 +58,10 @@ void Val::dispatch(T handler, Val* val) { break; } break; + case ValType::NamedScalar: + ptr(handler)->handle(val->as()); + return; + case ValType::IterDomain: ptr(handler)->handle(val->as()); return; @@ -67,8 +71,11 @@ void Val::dispatch(T handler, Val* val) { case ValType::TensorView: ptr(handler)->handle(val->as()); return; - case ValType::NamedScalar: - ptr(handler)->handle(val->as()); + case ValType::Predicate: + ptr(handler)->handle(val->as()); + return; + case ValType::TensorIndex: + ptr(handler)->handle(val->as()); return; default: break; @@ -79,12 +86,6 @@ void Val::dispatch(T handler, Val* val) { template void Expr::dispatch(T handler, Expr* expr) { switch (*(expr->getExprType())) { - case ExprType::Split: - ptr(handler)->handle(expr->as()); - return; - case ExprType::Merge: - ptr(handler)->handle(expr->as()); - return; case ExprType::UnaryOp: ptr(handler)->handle(expr->as()); return; @@ -103,6 +104,13 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Split: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Merge: + ptr(handler)->handle(expr->as()); + return; case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; @@ -115,6 +123,34 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::ViewOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Allocate: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Sync: + ptr(handler)->handle(expr->as()); + return; + case ExprType::InitMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::UpdateMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::ForLoop: + ptr(handler)->handle(expr->as()); + return; + case ExprType::IfThenElse: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridReduction: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridBroadcast: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridWelford: + ptr(handler)->handle(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -148,6 +184,10 @@ void Val::constDispatch(T handler, const Val* val) { break; } break; + case ValType::NamedScalar: + ptr(handler)->handle(val->as()); + return; + case ValType::IterDomain: ptr(handler)->handle(val->as()); return; @@ -157,8 +197,11 @@ void Val::constDispatch(T handler, const Val* val) { case ValType::TensorView: ptr(handler)->handle(val->as()); return; - case ValType::NamedScalar: - ptr(handler)->handle(val->as()); + case ValType::Predicate: + ptr(handler)->handle(val->as()); + return; + case ValType::TensorIndex: + ptr(handler)->handle(val->as()); return; default: break; @@ -169,12 +212,6 @@ void Val::constDispatch(T handler, const Val* val) { template void Expr::constDispatch(T handler, const Expr* expr) { switch (*(expr->getExprType())) { - case ExprType::Split: - ptr(handler)->handle(expr->as()); - return; - case ExprType::Merge: - ptr(handler)->handle(expr->as()); - return; case 
ExprType::UnaryOp: ptr(handler)->handle(expr->as()); return; @@ -193,6 +230,13 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Split: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Merge: + ptr(handler)->handle(expr->as()); + return; case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; @@ -205,6 +249,34 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::ViewOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Allocate: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Sync: + ptr(handler)->handle(expr->as()); + return; + case ExprType::InitMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::UpdateMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::ForLoop: + ptr(handler)->handle(expr->as()); + return; + case ExprType::IfThenElse: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridReduction: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridBroadcast: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridWelford: + ptr(handler)->handle(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -232,28 +304,42 @@ void Statement::constDispatch(T handler, const Statement* stmt) { * ptr(mutator)->mutate(this->as()); */ template -Statement* Val::mutatorDispatch(T mutator, Val* val) { +void Val::mutatorDispatch(T mutator, Val* val) { switch (*(val->getValType())) { case ValType::Scalar: switch (*(val->getDataType())) { case DataType::Bool: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case DataType::Double: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case DataType::Int: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; default: break; } break; + case ValType::NamedScalar: + ptr(mutator)->mutate(val->as()); + return; + case ValType::IterDomain: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case ValType::TensorDomain: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case ValType::TensorView: - return ptr(mutator)->mutate(val->as()); - case ValType::NamedScalar: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; + case ValType::Predicate: + ptr(mutator)->mutate(val->as()); + return; + case ValType::TensorIndex: + ptr(mutator)->mutate(val->as()); + return; default: break; } @@ -261,44 +347,87 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) { } template -Statement* Expr::mutatorDispatch(T mutator, Expr* expr) { +void Expr::mutatorDispatch(T mutator, Expr* expr) { switch (*(expr->getExprType())) { - case ExprType::Split: - return ptr(mutator)->mutate(expr->as()); - case ExprType::Merge: - return ptr(mutator)->mutate(expr->as()); case ExprType::UnaryOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::BinaryOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::TernaryOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ReductionOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::WelfordOp: - return ptr(mutator)->mutate(expr->as()); + 
ptr(mutator)->mutate(expr->as()); + return; case ExprType::BroadcastOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; + + case ExprType::Split: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::Merge: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::TransposeOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ShiftOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::GatherOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ViewOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; + + case ExprType::Allocate: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::Sync: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::InitMagicZero: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::UpdateMagicZero: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::ForLoop: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::IfThenElse: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::GridReduction: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::GridBroadcast: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::GridWelford: + ptr(mutator)->mutate(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } } template -Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) { +void Statement::mutatorDispatch(T mutator, Statement* stmt) { if (stmt->isVal()) { - return ptr(mutator)->mutate(stmt->as()); + ptr(mutator)->mutate(stmt->as()); + return; } if (stmt->isExpr()) { - return ptr(mutator)->mutate(stmt->as()); + ptr(mutator)->mutate(stmt->as()); + return; } TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!"); } @@ -308,11 +437,11 @@ Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) { * classes. Actual visitors/mutators should inhereit from these classes and call * ->dispatch(this) to avoid needing an explicit instantiation. 
*/ -template void Statement::dispatch(OptOutDispatch, Statement*); +template void Statement::dispatch(OptOutDispatch&, Statement*); template void Statement::dispatch(OptOutDispatch*, Statement*); -template void Val::dispatch(OptOutDispatch, Val*); +template void Val::dispatch(OptOutDispatch&, Val*); template void Val::dispatch(OptOutDispatch*, Val*); -template void Expr::dispatch(OptOutDispatch, Expr*); +template void Expr::dispatch(OptOutDispatch&, Expr*); template void Expr::dispatch(OptOutDispatch*, Expr*); template void Statement::dispatch(OptInDispatch, Statement*); @@ -322,33 +451,26 @@ template void Val::dispatch(OptInDispatch*, Val*); template void Expr::dispatch(OptInDispatch, Expr*); template void Expr::dispatch(OptInDispatch*, Expr*); -template void Statement::constDispatch(OptOutConstDispatch, const Statement*); +template void Statement::constDispatch(OptOutConstDispatch&, const Statement*); template void Statement::constDispatch(OptOutConstDispatch*, const Statement*); -template void Val::constDispatch(OptOutConstDispatch, const Val*); +template void Val::constDispatch(OptOutConstDispatch&, const Val*); template void Val::constDispatch(OptOutConstDispatch*, const Val*); -template void Expr::constDispatch(OptOutConstDispatch, const Expr*); +template void Expr::constDispatch(OptOutConstDispatch&, const Expr*); template void Expr::constDispatch(OptOutConstDispatch*, const Expr*); -template void Statement::constDispatch(OptInConstDispatch, const Statement*); +template void Statement::constDispatch(OptInConstDispatch&, const Statement*); template void Statement::constDispatch(OptInConstDispatch*, const Statement*); -template void Val::constDispatch(OptInConstDispatch, const Val*); +template void Val::constDispatch(OptInConstDispatch&, const Val*); template void Val::constDispatch(OptInConstDispatch*, const Val*); -template void Expr::constDispatch(OptInConstDispatch, const Expr*); +template void Expr::constDispatch(OptInConstDispatch&, const Expr*); template void Expr::constDispatch(OptInConstDispatch*, const Expr*); -template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*); -template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*); -template Statement* Val::mutatorDispatch(OptOutMutator, Val*); -template Statement* Val::mutatorDispatch(OptOutMutator*, Val*); -template Statement* Expr::mutatorDispatch(OptOutMutator, Expr*); -template Statement* Expr::mutatorDispatch(OptOutMutator*, Expr*); - -template Statement* Statement::mutatorDispatch(OptInMutator, Statement*); -template Statement* Statement::mutatorDispatch(OptInMutator*, Statement*); -template Statement* Val::mutatorDispatch(OptInMutator, Val*); -template Statement* Val::mutatorDispatch(OptInMutator*, Val*); -template Statement* Expr::mutatorDispatch(OptInMutator, Expr*); -template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*); +template void Statement::mutatorDispatch(OptOutMutator&, Statement*); +template void Statement::mutatorDispatch(OptOutMutator*, Statement*); +template void Val::mutatorDispatch(OptOutMutator&, Val*); +template void Val::mutatorDispatch(OptOutMutator*, Val*); +template void Expr::mutatorDispatch(OptOutMutator&, Expr*); +template void Expr::mutatorDispatch(OptOutMutator*, Expr*); void OptOutDispatch::handle(Statement* s) { Statement::dispatch(this, s); @@ -362,18 +484,6 @@ void OptOutDispatch::handle(Val* v) { Val::dispatch(this, v); } -void OptInDispatch::handle(Statement* s) { - Statement::dispatch(this, s); -} - -void OptInDispatch::handle(Expr* e) { - 
Expr::dispatch(this, e); -} - -void OptInDispatch::handle(Val* v) { - Val::dispatch(this, v); -} - void OptOutConstDispatch::handle(const Statement* s) { Statement::constDispatch(this, s); } @@ -386,46 +496,224 @@ void OptOutConstDispatch::handle(const Val* v) { Val::constDispatch(this, v); } -void OptInConstDispatch::handle(const Statement* s) { - Statement::constDispatch(this, s); +void OptInConstDispatch::unhandled(const Statement* stmt) { + if (stmt->isExpr()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getExprType().value(), "."); + } else if (stmt->isVal()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getValType().value(), "."); + } else { + TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type."); + } } -void OptInConstDispatch::handle(const Expr* e) { - Expr::constDispatch(this, e); +void OptInDispatch::unhandled(Statement* stmt) { + if (stmt->isExpr()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getExprType().value(), "."); + } else if (stmt->isVal()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getValType().value(), "."); + } else { + TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type."); + } } -void OptInConstDispatch::handle(const Val* v) { - Val::constDispatch(this, v); +// Vals +void OptOutConstDispatch::handle(const Bool* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const Double* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const Int* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const NamedScalar* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const IterDomain* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TensorDomain* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TensorView* stmt) { + unhandled(stmt); +} + +void OptOutConstDispatch::handle(const kir::Predicate* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) { + unhandled(stmt); +} + +// Exprs +void OptOutConstDispatch::handle(const UnaryOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const BinaryOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TernaryOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const ReductionOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const WelfordOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const BroadcastOp* stmt) { + unhandled(stmt); } -Statement* OptInMutator::mutate(Statement* s) { - return Statement::mutatorDispatch(this, s); +void OptOutConstDispatch::handle(const Split* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const Merge* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TransposeOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const ShiftOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const GatherOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const ViewOp* stmt) { + unhandled(stmt); } -Statement* OptInMutator::mutate(Expr* e) { - return Expr::mutatorDispatch(this, e); +void OptOutConstDispatch::handle(const kir::Allocate* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::Sync* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::UpdateMagicZero* stmt) { + 
unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::ForLoop* stmt) { + unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::IfThenElse* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::GridReduction* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::GridWelford* stmt) { + unhandled(stmt); +} + +void OptOutDispatch::unhandled(Statement*) {} -Statement* OptInMutator::mutate(Val* v) { - // If value is already mutated, return the mutation - if (mutations.find(v) != mutations.end()) - return mutations[v]; - return Val::mutatorDispatch(this, v); +// Vals +void OptOutDispatch::handle(Bool* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(Double* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(Int* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(NamedScalar* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(IterDomain* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TensorDomain* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TensorView* stmt) { + unhandled(stmt); } -Statement* OptOutMutator::mutate(Statement* s) { - return Statement::mutatorDispatch(this, s); +void OptOutDispatch::handle(kir::Predicate* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::TensorIndex* stmt) { + unhandled(stmt); } -Statement* OptOutMutator::mutate(Expr* e) { - return Expr::mutatorDispatch(this, e); +// Exprs +void OptOutDispatch::handle(UnaryOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(BinaryOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TernaryOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(ReductionOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(WelfordOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(BroadcastOp* stmt) { + unhandled(stmt); } -Statement* OptOutMutator::mutate(Val* v) { - // If value is already mutated, return the mutation - if (mutations.find(v) != mutations.end()) - return mutations[v]; - return Val::mutatorDispatch(this, v); +void OptOutDispatch::handle(Split* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(Merge* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TransposeOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(ShiftOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(GatherOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(ViewOp* stmt) { + unhandled(stmt); +} + +void OptOutDispatch::handle(kir::Allocate* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::Sync* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::InitMagicZero* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::UpdateMagicZero* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::ForLoop* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::IfThenElse* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::GridReduction* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::GridBroadcast* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::GridWelford* stmt) { + unhandled(stmt); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index c1be76eb950..6961ebd6a15 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -1,9 +1,9 @@ #pragma once -#include - 
+#include #include -#include + +#include #include @@ -48,7 +48,7 @@ namespace torch { namespace jit { namespace fuser { namespace cuda { - +class IrContainer; class Fusion; // Hierarchal dispatch functions for handle @@ -60,14 +60,13 @@ class Val; class IterDomain; class TensorDomain; class TensorView; + class Bool; class Double; class Int; class NamedScalar; // Exprs -class Split; -class Merge; class UnaryOp; class BinaryOp; class TernaryOp; @@ -79,9 +78,35 @@ class ShiftOp; class GatherOp; class ViewOp; +// Exprs +class Split; +class Merge; +class TransposeOp; +class ShiftOp; +class GatherOp; +class ViewOp; + +namespace kir { +class Predicate; +class TensorIndex; + +class Allocate; +class Sync; +class ForLoop; +class IfThenElse; +class GridReduction; +class GridBroadcast; +class GridWelford; +class InitMagicZero; +class UpdateMagicZero; +} // namespace kir + // By default, all IR nodes are handled in this dispatch, and will call an empty // function on all nodes. class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { + protected: + virtual void unhandled(const Statement*) {} + public: // Hierarchal dispatch functions for handle virtual void handle(const Statement*); @@ -89,30 +114,47 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const Val*); // Vals - virtual void handle(const IterDomain*) {} - virtual void handle(const TensorDomain*) {} - virtual void handle(const TensorView*) {} - virtual void handle(const Bool*) {} - virtual void handle(const Double*) {} - virtual void handle(const Int*) {} - virtual void handle(const NamedScalar*) {} + virtual void handle(const IterDomain* stmt); + virtual void handle(const TensorDomain* stmt); + virtual void handle(const TensorView* stmt); + virtual void handle(const Bool* stmt); + virtual void handle(const Double* stmt); + virtual void handle(const Int* stmt); + virtual void handle(const NamedScalar* stmt); + + virtual void handle(const kir::Predicate*); + virtual void handle(const kir::TensorIndex*); // Exprs - virtual void handle(const Split*) {} - virtual void handle(const Merge*) {} - virtual void handle(const UnaryOp*) {} - virtual void handle(const BinaryOp*) {} - virtual void handle(const TernaryOp*) {} - virtual void handle(const ReductionOp*) {} - virtual void handle(const WelfordOp*) {} - virtual void handle(const BroadcastOp*) {} - virtual void handle(const TransposeOp*) {} - virtual void handle(const ShiftOp*) {} - virtual void handle(const GatherOp*) {} - virtual void handle(const ViewOp*) {} + virtual void handle(const UnaryOp* stmt); + virtual void handle(const BinaryOp* stmt); + virtual void handle(const TernaryOp* stmt); + virtual void handle(const ReductionOp* stmt); + virtual void handle(const WelfordOp* stmt); + virtual void handle(const BroadcastOp* stmt); + + virtual void handle(const Split* stmt); + virtual void handle(const Merge* stmt); + virtual void handle(const TransposeOp* stmt); + virtual void handle(const ShiftOp* stmt); + virtual void handle(const GatherOp* stmt); + virtual void handle(const ViewOp* stmt); + + virtual void handle(const kir::Allocate*); + virtual void handle(const kir::Sync*); + virtual void handle(const kir::InitMagicZero*); + virtual void handle(const kir::UpdateMagicZero*); + virtual void handle(const kir::ForLoop*); + virtual void handle(const kir::IfThenElse*); + virtual void handle(const kir::GridReduction*); + virtual void handle(const kir::GridBroadcast*); + virtual void handle(const kir::GridWelford*); }; class TORCH_CUDA_CU_API 
OptOutDispatch : public PolymorphicBase { + protected: + virtual void unhandled(Statement*); + public: // Hierarchal dispatch functions for handle virtual void handle(Statement*); @@ -120,190 +162,88 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(Val*); // Vals - virtual void handle(IterDomain*) {} - virtual void handle(TensorDomain*) {} - virtual void handle(TensorView*) {} - virtual void handle(Bool*) {} - virtual void handle(Double*) {} - virtual void handle(Int*) {} - virtual void handle(NamedScalar*) {} + virtual void handle(Bool* stmt); + virtual void handle(Double* stmt); + virtual void handle(Int* stmt); + virtual void handle(NamedScalar* stmt); + virtual void handle(IterDomain* stmt); + virtual void handle(TensorDomain* stmt); + virtual void handle(TensorView* stmt); + + virtual void handle(kir::Predicate*); + virtual void handle(kir::TensorIndex*); // Exprs - virtual void handle(Split*) {} - virtual void handle(Merge*) {} - virtual void handle(UnaryOp*) {} - virtual void handle(BinaryOp*) {} - virtual void handle(TernaryOp*) {} - virtual void handle(ReductionOp*) {} - virtual void handle(WelfordOp*) {} - virtual void handle(BroadcastOp*) {} - virtual void handle(TransposeOp*) {} - virtual void handle(ShiftOp*) {} - virtual void handle(GatherOp*) {} - virtual void handle(ViewOp*) {} + virtual void handle(UnaryOp* stmt); + virtual void handle(BinaryOp* stmt); + virtual void handle(TernaryOp* stmt); + virtual void handle(ReductionOp* stmt); + virtual void handle(WelfordOp* stmt); + virtual void handle(BroadcastOp* stmt); + + virtual void handle(Split* stmt); + virtual void handle(Merge* stmt); + virtual void handle(TransposeOp* stmt); + virtual void handle(ShiftOp* stmt); + virtual void handle(GatherOp* stmt); + virtual void handle(ViewOp* stmt); + + virtual void handle(kir::Allocate* stmt); + virtual void handle(kir::Sync* stmt); + virtual void handle(kir::InitMagicZero* stmt); + virtual void handle(kir::UpdateMagicZero* stmt); + virtual void handle(kir::ForLoop* stmt); + virtual void handle(kir::IfThenElse* stmt); + virtual void handle(kir::GridReduction* stmt); + virtual void handle(kir::GridBroadcast* stmt); + virtual void handle(kir::GridWelford* stmt); }; -class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase { +class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch { public: - // Hierarchal dispatch functions for handle - virtual void handle(const Statement*); - virtual void handle(const Expr*); - virtual void handle(const Val*); - - // Vals - virtual void handle(const IterDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain."); - } - virtual void handle(const TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain."); - } - virtual void handle(const TensorView*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView."); - } - virtual void handle(const Bool*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool."); - } - virtual void handle(const Double*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double."); - } - virtual void handle(const Int*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int."); - } - virtual void handle(const NamedScalar*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar."); - } + using OptOutConstDispatch::handle; - // Exprs - virtual void handle(const Split*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split."); - } - virtual void 
handle(const Merge*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); - } - virtual void handle(const UnaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); - } - virtual void handle(const BinaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp."); - } - virtual void handle(const WelfordOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp."); - } - virtual void handle(const TernaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp."); - } - virtual void handle(const ReductionOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); - } - virtual void handle(const BroadcastOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); - } - virtual void handle(const TransposeOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp."); - } - virtual void handle(const ShiftOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp."); - } - virtual void handle(const GatherOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp."); - } - virtual void handle(const ViewOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ViewOp."); - } + protected: + virtual void unhandled(const Statement* stmt) final; }; -class TORCH_CUDA_CU_API OptInDispatch : public PolymorphicBase { +class TORCH_CUDA_CU_API OptInDispatch : public OptOutDispatch { public: - // Hierarchal dispatch functions for handle - virtual void handle(Statement* s); - virtual void handle(Expr* e); - virtual void handle(Val* v); - - // Vals - virtual void handle(IterDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain."); - } - virtual void handle(TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain."); - } - virtual void handle(TensorView*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView."); - } - virtual void handle(Bool*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool."); - } - virtual void handle(Double*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double."); - } - virtual void handle(Int*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int."); - } - virtual void handle(NamedScalar*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar."); - } + using OptOutDispatch::handle; - // Exprs - virtual void handle(Split*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split."); - } - virtual void handle(Merge*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); - } - virtual void handle(UnaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); - } - virtual void handle(BinaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp."); - } - virtual void handle(TernaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp."); - } - virtual void handle(ReductionOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); - } - virtual void handle(WelfordOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp."); - } - virtual void handle(BroadcastOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); - } - virtual void handle(TransposeOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp."); - } - virtual void handle(ShiftOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp."); - } - virtual void 
handle(GatherOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp."); - } - virtual void handle(ViewOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ViewOp."); - } + protected: + virtual void unhandled(Statement* stmt) final; }; +// Class to perform mutations on Fusion IR. Exprs can simply be redefined, but +// when mutating values they have to be registered through registerMutation so +// that exprs can detect there's been a mutation and know to modify all +// instances of that Val. This means each Val should be mutated "consistently". +// Otherwise behavior may be difficult to understand as it depends on which +// order mutate is called in. This class expects the user to call mutate on the +// statements of interest in topological order, so that inputs are mutated +// before the exprs that depend on them. +// +// Warning: TensorViews need to be treated carefully, as we don't generally +// register their mutation when only their tensor domains change. If a TV needs +// to be swapped out, it needs to be registered as a "proper" mutation like +// other vals, on top of TensorDomain being updated in the mutated TensorView. +// // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { public: // Hierarchal dispatch functions for handle - virtual Statement* mutate(Statement* s); - virtual Statement* mutate(Expr* e); - virtual Statement* mutate(Val* v); - - // We always want to dispatch through a Val, so we can capture and dispatch - // correctly members of nodes like Split->TensorDomain If we don't call the - // below function or manually cast to use mutate(Val* v) we can't intercept - // and mutate by capturing mutate(Val* v), which is what we do when we want to - // replace all instances of a value.
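To make the mutator contract described above concrete, here is a minimal sketch of how a pass might drive the new void-returning mutate API; it is not part of this patch, and the ReplaceExtent name and the traversal loop are made up for illustration:

    // Hypothetical pass that swaps one extent Val for another everywhere.
    class ReplaceExtent : public OptOutMutator {
     public:
      ReplaceExtent(Val* old_extent, Val* new_extent) {
        // Register the replacement once; per the class comment, mutate() then
        // detects the mutation (via maybeMutated) and updates statements that
        // refer to old_extent.
        registerMutation(old_extent, new_extent);
      }
    };
    // Usage: visit statements in topological order so inputs are mutated
    // before the exprs that depend on them.
    //   ReplaceExtent pass(old_extent, new_extent);
    //   for (auto stmt : stmts_in_topological_order) { pass.mutate(stmt); }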
- Statement* mutateAsVal(Val* v) { - return mutate(v); - } + virtual void mutate(Statement* s); + virtual void mutate(Expr* e); + virtual void mutate(Val* v); + + void registerMutation(Val* val, Val* mutation); - void registerMutation(Val* val, Val* mutation) { - TORCH_INTERNAL_ASSERT( - mutations.find(val) == mutations.end(), - " The same value is incorrectly being mutated twice.", - " One mutation per mutation pass is allowed."); - mutations[val] = mutation; + Val* maybeMutated(Val* val) { + if (mutations.find(val) == mutations.end()) { + return val; + } + return mutations.at(val); } std::unordered_map mutations; @@ -311,105 +251,44 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { //****Functions below defined in mutator.cpp***** // Vals - virtual Statement* mutate(IterDomain*); - virtual Statement* mutate(TensorDomain*); - virtual Statement* mutate(TensorView*); - virtual Statement* mutate(Bool*); - virtual Statement* mutate(Double*); - virtual Statement* mutate(Int*); - virtual Statement* mutate(NamedScalar*); - - // Exprs - virtual Statement* mutate(Split*); - virtual Statement* mutate(Merge*); - virtual Statement* mutate(UnaryOp*); - virtual Statement* mutate(BinaryOp*); - virtual Statement* mutate(TernaryOp*); - virtual Statement* mutate(ReductionOp*); - virtual Statement* mutate(WelfordOp*); - virtual Statement* mutate(BroadcastOp*); - virtual Statement* mutate(TransposeOp*); - virtual Statement* mutate(ShiftOp*); - virtual Statement* mutate(GatherOp*); - virtual Statement* mutate(ViewOp*); -}; - -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase { - public: - std::unordered_map mutations; - - public: - void registerMutation(Val* val, Val* mutation) { - TORCH_INTERNAL_ASSERT( - mutations.find(val) == mutations.end(), - " The same value is incorrectly being mutated twice.", - " One mutation per mutation pass is allowed."); - mutations[val] = mutation; - } - - // Hierarchal dispatch functions for mutate - virtual Statement* mutate(Statement*); - virtual Statement* mutate(Expr*); - virtual Statement* mutate(Val*); + virtual void mutate(Bool*); + virtual void mutate(Double*); + virtual void mutate(Int*); + virtual void mutate(NamedScalar*); + virtual void mutate(IterDomain*); + virtual void mutate(TensorDomain*); + virtual void mutate(TensorView*); - // Vals - virtual Statement* mutate(IterDomain*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for IterDomain."); - } - virtual Statement* mutate(TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorDomain."); - } - virtual Statement* mutate(TensorView*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorView."); - } - virtual Statement* mutate(Bool*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Bool."); - } - virtual Statement* mutate(Int*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Int."); - } - virtual Statement* mutate(NamedScalar*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for NamedScalar."); - } + virtual void mutate(kir::Predicate*); + virtual void mutate(kir::TensorIndex*); // Exprs - virtual Statement* mutate(Split*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Split."); - } - virtual Statement* mutate(Merge*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Merge."); - } - virtual Statement* mutate(UnaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for UnaryOp."); - } - virtual Statement* mutate(BinaryOp*) { - 
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BinaryOp."); - } - virtual Statement* mutate(TernaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TernaryOp."); - } - virtual Statement* mutate(ReductionOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ReductionOp."); - } - virtual Statement* mutate(WelfordOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for WelfordOp."); - } - virtual Statement* mutate(BroadcastOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BroadcastOp."); - } - virtual Statement* mutate(TransposeOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TransposeOp."); - } - virtual Statement* mutate(ShiftOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ShiftOp."); - } - virtual Statement* mutate(GatherOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GatherOp."); - } - virtual Statement* mutate(ViewOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ViewOp."); - } + virtual void mutate(UnaryOp*); + virtual void mutate(BinaryOp*); + virtual void mutate(TernaryOp*); + virtual void mutate(ReductionOp*); + virtual void mutate(WelfordOp*); + virtual void mutate(BroadcastOp*); + + virtual void mutate(Split*); + virtual void mutate(Merge*); + virtual void mutate(TransposeOp*); + virtual void mutate(ShiftOp*); + virtual void mutate(GatherOp*); + virtual void mutate(ViewOp*); + + virtual void mutate(kir::Allocate*); + virtual void mutate(kir::Sync*); + virtual void mutate(kir::InitMagicZero*); + virtual void mutate(kir::UpdateMagicZero*); + virtual void mutate(kir::ForLoop*); + virtual void mutate(kir::IfThenElse*); + virtual void mutate(kir::GridReduction*); + virtual void mutate(kir::GridBroadcast*); + virtual void mutate(kir::GridWelford*); + + protected: + void removeExpr(IrContainer*, Expr*); }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp index 288dbb198b0..09481319569 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp @@ -1,9 +1,11 @@ -#include #include #include +#include #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -68,8 +70,8 @@ std::vector makeSortedEvaluationList(std::vector input) { //! Kernel IR utility, collects all the symbolic integers //! used in allocation nodes. void collectBufferSizes( - std::vector& into, - const std::vector& exprs) { + std::vector& into, + const std::vector& exprs) { for (auto expr : exprs) { if (auto allocate = dynamic_cast(expr)) { into.push_back(allocate->size()); @@ -82,56 +84,44 @@ void collectBufferSizes( } } -//! Kernel IR utility, collects all the kir symbolic +//! Kernel IR utility, collects all the kernel symbolic //! integers we will need at runtime, i.e. after the //! generated cuda kernel has already been compiled. //! The values are to be used for runtime logic, like //! `computeLaunchparams`. 
-std::vector collectRuntimeUsedIntegers( - Fusion* fusion, - GpuLower* lower) { - std::vector ret; - +std::vector collectRuntimeUsedIntegers(kir::Kernel* kernel) { + std::vector ret; + auto all_tvs = ir_utils::allTvs(kernel); // Collect extent and integer inputs - for (auto val : fusion->usedMathVals()) { - auto kir_val = lower->lowerValue(val); - if (auto kir_tv = dynamic_cast(kir_val)) { - for (auto id : kir_tv->domain()->domain()) { - ret.push_back(id->extent()); - } - } else if (val->isFusionInput()) { - if (kir_val->isA()) { - ret.push_back(kir_val); - } + for (auto tv : all_tvs) { + for (auto id : tv->domain()->domain()) { + ret.push_back(id->extent()); + } + } + for (auto inp : kernel->inputs()) { + if (inp->isA()) { + ret.push_back(inp); } } - // Collect allocation sizes: - collectBufferSizes(ret, lower->kernel()->topLevelExprs()); - + collectBufferSizes(ret, kernel->topLevelExprs()); return makeSortedEvaluationList(ret); } -//! Fusion IR utility, collects all the fusionIR symbolic -//! integers we will need at runtime, i.e. after the -//! generated cuda kernel has already been compiled. -//! The values are to be used for runtime logic, like -//! `canSchedule` in heuristic look up. + std::vector collectRuntimeUsedIntegers(Fusion* fusion) { std::vector ret; - + auto all_tvs = ir_utils::allTvs(fusion); // Collect extent and integer inputs - for (auto val : fusion->usedMathVals()) { - if (auto tv = dynamic_cast(val)) { - for (auto id : tv->domain()->domain()) { - ret.push_back(id->extent()); - } - } else if (val->isFusionInput()) { - if (val->isA()) { - ret.push_back(val); - } + for (auto tv : all_tvs) { + for (auto id : tv->domain()->domain()) { + ret.push_back(id->extent()); + } + } + for (auto inp : fusion->inputs()) { + if (inp->isA()) { + ret.push_back(inp); } } - return makeSortedEvaluationList(ret); } @@ -140,7 +130,7 @@ std::vector collectRuntimeUsedIntegers(Fusion* fusion) { template void PrecomputedIntegersBase::initializeValueList( typename IRContext::EVALUATOR_TYPE& const_evaluator, - const std::vector& sorted_value_list) { + const std::vector& sorted_value_list) { // Initialize workspace num_of_values_ = sorted_value_list.size(); defined_ = std::vector(num_of_values_, false); @@ -161,7 +151,7 @@ void PrecomputedIntegersBase::initializeValueList( template c10::optional PrecomputedIntegersBase::getMaybeValueFor( - const IR_VAL* val) { + const Val* val) { auto index = val->evaluatorIndex(); if (index < 0) { return c10::nullopt; @@ -172,6 +162,17 @@ c10::optional PrecomputedIntegersBase::getMaybeValueFor( return values_[index]; } +template +void PrecomputedIntegersBase::print() const { + std::cout << "Precomputed Integers:\n"; + for (auto i : c10::irange(symbols_.size())) { + if (defined_[i]) { + std::cout << symbols_[i]->toInlineString() << " = " << values_[i] + << std::endl; + } + } +} + template void PrecomputedIntegersBase::evaluate() { FUSER_PERF_SCOPE("PrecomputedIntegers::Evaluate"); @@ -208,10 +209,9 @@ NaiveIntegerMachine::NaiveIntegerMachine( for (auto val : precomputed_integers_.symbols_) { auto def = val->definition(); if (def) { - if (auto uop = dynamic_cast(def)) { + if (auto uop = dynamic_cast(def)) { makeUnaryOp(uop); - } else if ( - auto bop = dynamic_cast(def)) { + } else if (auto bop = dynamic_cast(def)) { makeBinaryOp(bop); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr"); @@ -234,8 +234,7 @@ void NaiveIntegerMachine::run() { } template -void NaiveIntegerMachine::makeUnaryOp( - typename IRContext::UNARY_OP_TYPE* uop) { +void 
NaiveIntegerMachine::makeUnaryOp(UnaryOp* uop) { int in = uop->inputs()[0]->evaluatorIndex(); int out = uop->outputs()[0]->evaluatorIndex(); TORCH_INTERNAL_ASSERT(in >= 0, "Integer Machine: unknown input: ", uop); @@ -249,8 +248,7 @@ void NaiveIntegerMachine::makeUnaryOp( } template -void NaiveIntegerMachine::makeBinaryOp( - typename IRContext::BINARY_OP_TYPE* bop) { +void NaiveIntegerMachine::makeBinaryOp(BinaryOp* bop) { int in0 = bop->inputs()[0]->evaluatorIndex(); int in1 = bop->inputs()[1]->evaluatorIndex(); int out = bop->outputs()[0]->evaluatorIndex(); @@ -377,11 +375,8 @@ void NaiveIntegerMachine::runBinaryOp(int index) { precomputed_integers_.defined_[dest_index] = true; } -KernelPrecomputedIntegers::KernelPrecomputedIntegers( - Fusion* fusion, - GpuLower& lower) - : lower_(&lower) { - loadSymbols(collectRuntimeUsedIntegers(fusion, lower_)); +KernelPrecomputedIntegers::KernelPrecomputedIntegers(kir::Kernel* kernel) { + loadSymbols(collectRuntimeUsedIntegers(kernel)); kir::ExpressionEvaluator evaluator; initializeValueList(evaluator, symbols()); initializeNamedScalars(); @@ -389,11 +384,11 @@ KernelPrecomputedIntegers::KernelPrecomputedIntegers( } void KernelPrecomputedIntegers::bindTensorMetaData( - kir::TensorView* tv, + TensorView* tv, const at::Tensor& at_tensor) { - std::vector> ret; + std::vector> ret; const auto root_domain = - kir::TensorDomain::noReductions(tv->domain()->rootDomain()); + TensorDomain::noReductions(tv->domain()->getRootDomain()); TORCH_INTERNAL_ASSERT( at_tensor.ndimension() == static_cast(root_domain.size()), "Something went wrong configuring launch. Inputs do not match."); @@ -411,7 +406,7 @@ namespace { //! and returns the corresponding parallel type if a match //! is found. c10::optional getMaybeThreadSizeParallelType( - kir::NamedScalar* named_scalar) { + NamedScalar* named_scalar) { auto& var_name = named_scalar->name(); for (auto ptype : kParallelTypeThreads) { if (var_name == stringifyThreadSize(ptype)) { @@ -425,7 +420,7 @@ c10::optional getMaybeThreadSizeParallelType( void KernelPrecomputedIntegers::initializeNamedScalars() { for (auto val : symbols()) { - if (auto named_scalar = dynamic_cast(val)) { + if (auto named_scalar = dynamic_cast(val)) { auto maybe_parallel_type = getMaybeThreadSizeParallelType(named_scalar); if (maybe_parallel_type.has_value()) { auto& index_list = @@ -440,17 +435,17 @@ void KernelPrecomputedIntegers::initializeNamedScalars() { } void KernelPrecomputedIntegers::bindKernelInputs( + kir::Kernel* kernel, const at::ArrayRef& aten_inputs) { if (hasValidValues()) { invalidate(); } - auto kernel = lower_->kernel(); const auto& inputs = kernel->inputs(); for (const auto i : c10::irange(inputs.size())) { const auto input = inputs[i]; - if (auto tensor_input = dynamic_cast(input)) { + if (auto tensor_input = dynamic_cast(input)) { const auto aten_tensor = aten_inputs[i].toTensor(); bindTensorMetaData(tensor_input, aten_tensor); } else if (input->isScalar() && input->dtype() == DataType::Int) { diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.h b/torch/csrc/jit/codegen/cuda/evaluator_common.h index 0c16e2a8b04..7cbe37c602b 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.h +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.h @@ -35,18 +35,14 @@ class ExpressionEvaluator; //! 
Context for using generic logic on FusionIR class FusionIRContext { public: - using VAL_TYPE = Val; - using EXPR_TYPE = Expr; using TV_TYPE = TensorView; using EVALUATOR_TYPE = ExpressionEvaluator; - using BINARY_OP_TYPE = BinaryOp; - using UNARY_OP_TYPE = UnaryOp; - static BinaryOpType getOpType(BINARY_OP_TYPE* bop) { + static BinaryOpType getOpType(BinaryOp* bop) { return bop->getBinaryOpType(); } - static UnaryOpType getOpType(UNARY_OP_TYPE* uop) { + static UnaryOpType getOpType(UnaryOp* uop) { return uop->getUnaryOpType(); } }; @@ -54,19 +50,14 @@ class FusionIRContext { //! Context for using generic logic on KernelIR class KernelIRContext { public: - using VAL_TYPE = kir::Val; - using EXPR_TYPE = kir::Expr; - using TV_TYPE = kir::TensorView; using EVALUATOR_TYPE = kir::ExpressionEvaluator; - using BINARY_OP_TYPE = kir::BinaryOp; - using UNARY_OP_TYPE = kir::UnaryOp; - static BinaryOpType getOpType(BINARY_OP_TYPE* bop) { - return bop->operation(); + static BinaryOpType getOpType(BinaryOp* bop) { + return bop->getBinaryOpType(); } - static UnaryOpType getOpType(UNARY_OP_TYPE* uop) { - return uop->operation(); + static UnaryOpType getOpType(UnaryOp* uop) { + return uop->getUnaryOpType(); } }; @@ -97,10 +88,10 @@ class NaiveIntegerMachine { private: //! Convert an unary IR expr to an instruction - void makeUnaryOp(typename IRContext::UNARY_OP_TYPE* uop); + void makeUnaryOp(UnaryOp* uop); //! Convert an binary IR expr to an instruction - void makeBinaryOp(typename IRContext::BINARY_OP_TYPE* bop); + void makeBinaryOp(BinaryOp* bop); //! Create an empty instruction with all default values //! and place it at the end of the instruction buffer. @@ -169,11 +160,6 @@ class NaiveIntegerMachine { //! integers and store them in the workspace ahead of time. template class PrecomputedIntegersBase { - using IR_UNARY_OP = typename IRContext::UNARY_OP_TYPE; - using IR_BINARY_OP = typename IRContext::BINARY_OP_TYPE; - using IR_VAL = typename IRContext::VAL_TYPE; - using IR_EXPR = typename IRContext::EXPR_TYPE; - using IR_TV = typename IRContext::TV_TYPE; using INTEGER_MACHINE = NaiveIntegerMachine; public: @@ -190,7 +176,10 @@ class PrecomputedIntegersBase { //! Returns value for the given IR node if it's stored //! in the workspace and has been evaluated. - c10::optional getMaybeValueFor(const IR_VAL* val); + c10::optional getMaybeValueFor(const Val* val); + + //! Debugging helper, prints all the currently known values + void print() const; protected: //! Initialize the workspace before first use. @@ -198,7 +187,7 @@ class PrecomputedIntegersBase { //! been topologically sorted. void initializeValueList( typename IRContext::EVALUATOR_TYPE& evaluator, - const std::vector& sorted_value_list); + const std::vector& sorted_value_list); //! Bind concrete value to the given index //! if the index is valid. @@ -215,12 +204,12 @@ class PrecomputedIntegersBase { void invalidate(); //! Interface for subclasses to access symbols_ - void loadSymbols(std::vector symbols) { + void loadSymbols(std::vector symbols) { symbols_ = std::move(symbols); } //! Interface for subclasses to access symbols_ - std::vector& symbols() { + std::vector& symbols() { return symbols_; } @@ -267,7 +256,7 @@ class PrecomputedIntegersBase { std::vector values_; //! Stores the IR nodes corresponding to each index. - std::vector symbols_; + std::vector symbols_; //! An internal log to keep track of all the bindings //! used in each evaluation cycle. 
To be used for @@ -308,12 +297,14 @@ class KernelPrecomputedIntegers public: using ParallelExtentMap = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; - KernelPrecomputedIntegers(Fusion* fusion, GpuLower& lower); + KernelPrecomputedIntegers(kir::Kernel* kernel); //! Bind concrete values from fusion runtime inputs - void bindKernelInputs(const at::ArrayRef& aten_inputs); + void bindKernelInputs( + kir::Kernel* kernel, + const at::ArrayRef& aten_inputs); //! Bind concrete values from launch constraints void bindParallelExtents( @@ -326,7 +317,7 @@ class KernelPrecomputedIntegers void bindConcreteParallelTypeValue(ParallelType pt, int64_t value); private: - void bindTensorMetaData(kir::TensorView* tv, const at::Tensor& at_tensor); + void bindTensorMetaData(TensorView* tv, const at::Tensor& at_tensor); //! Iterate through all the named scalars corresponding //! to thread sizes and pre-group them by their parallel @@ -334,8 +325,6 @@ class KernelPrecomputedIntegers void initializeNamedScalars(); private: - GpuLower* lower_ = nullptr; - //! Contains all the named scalars correspond //! to thread size of each parallel type. std::unordered_map>, TypeHash> diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 647cf4ec0e2..5e6f2d9375e 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -1,3 +1,4 @@ + #include #include @@ -8,21 +9,11 @@ #include #include #include -#include #include #include #include #include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - #include #include #include @@ -108,8 +99,6 @@ void FusionExecutor::debugCompileFusionFromStr( const std::string& name, int id, CompileOptions options) { - fusion_ = *fusion; - FusionGuard fg(&fusion_); options_ = options; if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) { @@ -126,11 +115,12 @@ void FusionExecutor::debugCompileFusionFromStr( << std::endl; } - setUsedTVs(); + lowered_ = std::make_unique(fusion); + const auto kernel = lowered_->kernel(); + fusion_ = lowered_->kernel(); fusion_id_ = id; - lowered_ = GpuLower(&fusion_); - const auto kernel = lowered_.kernel(); + setUsedTVs(); if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) { kernel->print(); @@ -155,9 +145,9 @@ void FusionExecutor::debugCompileFusionFromStr( void FusionExecutor::compileFusion( Fusion* fusion, - CompileOptions options, const at::ArrayRef& inputs, - const LaunchParams& launch_constraints) { + const LaunchParams& launch_constraints, + CompileOptions options) { FUSER_PERF_SCOPE("compileFusion"); TORCH_INTERNAL_ASSERT( @@ -175,9 +165,6 @@ void FusionExecutor::compileFusion( fusion->printMath(); } - // Clone the fusion so we can store it - fusion_ = *fusion; - FusionGuard fg(&fusion_); options_ = options; c10::DeviceGuard dg(options_.device); @@ -187,11 +174,12 @@ void FusionExecutor::compileFusion( max_device_smem = properties->sharedMemPerBlock; warp_size_ = properties->warpSize; - setUsedTVs(); + lowered_ = std::make_unique(fusion); + const auto kernel = lowered_->kernel(); + fusion_ = lowered_->kernel()->as(); fusion_id_ = ++fusion_id_counter_; - lowered_ = GpuLower(&fusion_); - const auto kernel = lowered_.kernel(); + setUsedTVs(); if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) { kernel->print(); @@ -216,7 +204,7 @@ void FusionExecutor::compileFusion( std::stringstream ss; ss << "Allocations must be based on constant integers for local memory. 
However, found: "; for (auto alloc : kernel_summary.dynamic_lmem_allocations) { - ss << toString(alloc->buffer(), false) << ", "; + ss << alloc->buffer()->toString() << ", "; } ss << " have dynamic allocations but are placed in local memory."; TORCH_INTERNAL_ASSERT(false, ss.str()); @@ -233,6 +221,8 @@ void FusionExecutor::compileFusion( block_size > 0, "launch param inferred block size < 0"); } + block_size_high_water_mark = + block_size.has_value() ? block_size.value() : block_size_high_water_mark; compiled_kernel_ = executor_utils::nvrtcCompile( structured_code, (kernelNamespace() + "::" + kernelName()).c_str(), @@ -245,8 +235,8 @@ void FusionExecutor::compileFusion( namespace { at::Tensor inferAndAlloc( - const kir::TensorView* tv, - const std::vector& sizes, + const TensorView* tv, + const std::vector& sizes, kir::ExpressionEvaluator& expr_eval, const CompileOptions& options, bool zero_init = false) { @@ -260,9 +250,11 @@ at::Tensor inferAndAlloc( TORCH_INTERNAL_ASSERT( inferred_val.has_value(), "Could not launch kernel as program could not infer ", - kir::toString(size), - " for the buffer ", - kir::toString(tv)); + size->toString(), + "(", + size->name(), + ") for the buffer ", + tv->toString()); inferred_sizes.push_back(inferred_val.value()); } @@ -283,19 +275,20 @@ at::Tensor inferAndAlloc( } at::Tensor inferAndAllocOutput( - const kir::TensorView* tv, + const TensorView* tv, kir::ExpressionEvaluator& expr_eval, const CompileOptions& options, bool zero_init = false) { const auto domain = tv->domain(); - const auto maybe_rfactor_domain = - domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain(); + const auto maybe_rfactor_domain = domain->hasRFactor() + ? domain->getRFactorDomain() + : domain->getRootDomain(); - std::vector sizes; + std::vector sizes; for (const auto id : maybe_rfactor_domain) { if (id->isReduction() || id->isStride() || - id->iterType() == IterType::BroadcastWithoutStride) { + id->getIterType() == IterType::BroadcastWithoutStride) { continue; } sizes.push_back(id->extent()); @@ -348,8 +341,7 @@ LaunchParams FusionExecutor::computeLaunchParams( auto data_cache = compileTimeDataCache(); - auto& lower = lowered_; - + auto lower = lowered_.get(); auto& used_tvs = getUsedTVs(); auto parallel_binding_ids_entry = executor_utils::caching::ExecutorCompileTimeEntry< @@ -364,9 +356,8 @@ LaunchParams FusionExecutor::computeLaunchParams( auto parallel_iter_extent_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::ParallelIterExtentMap>( - data_cache, [¶llel_binding_ids, &lower]() { - return executor_utils::getParallelIterExtents( - lower, parallel_binding_ids); + data_cache, [¶llel_binding_ids]() { + return executor_utils::getParallelIterExtents(parallel_binding_ids); }); auto& parallel_iter_extents = parallel_iter_extent_entry.get(); @@ -385,7 +376,7 @@ LaunchParams FusionExecutor::computeLaunchParams( executor_utils::caching::WarpPaddedParallelExtents>( data_cache, [¶llel_binding_ids, &lower]() { return executor_utils::getWarpPaddedExtentsInfo( - lower, parallel_binding_ids); + lower->kernel(), parallel_binding_ids); }); auto& warp_padded_extent_set = warp_padded_parallel_entry.get().warp_padded_extent_set; @@ -446,7 +437,9 @@ LaunchParams FusionExecutor::computeLaunchParams( auto val = expr_eval.evaluate(extent); TORCH_INTERNAL_ASSERT( val.has_value(), - "Tried to evaluate the extent of ", + "Tried to evaluate the extent, ", + extent->toInlineString(), + " for the ptype: ", p_type, " to set launch bounds but could not."); @@ 
-481,14 +474,15 @@ LaunchParams FusionExecutor::computeLaunchParams( expr_eval.precomputedIntegers()->evaluate(); } - const auto kernel = lowered_.kernel(); + const auto kernel = lowered_->kernel(); const auto& kernel_summary = kernel->summary(); // Calculate Dynamic Shared Memory Size // Add workspace for reduction and broadcast uint64_t reduction_broadcast_workspace = 0; const bool has_workspace = kernel_summary.has_block_reductions || - kernel_summary.has_grid_reductions || kernel_summary.has_block_broadcasts; + kernel_summary.has_grid_reductions || + kernel_summary.has_block_broadcasts || kernel_summary.has_grid_broadcasts; if (has_workspace && kernel_summary.largest_smem_data_type != DataType::Null) { // Not using nThreads here since it does not handle uninitialized value @@ -533,14 +527,14 @@ FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals( kir::ExpressionEvaluator& expr_eval) { FUSER_PERF_SCOPE("FusionExecutor::AllocGlobalVals"); GlobalBuffers global_buffers; - const auto kernel = lowered_.kernel(); - const auto& kernel_summary = lowered_.kernel()->summary(); + const auto kernel = lowered_->kernel(); + const auto& kernel_summary = lowered_->kernel()->summary(); for (auto alloc : kernel_summary.global_allocations) { TORCH_INTERNAL_ASSERT( - alloc->buffer()->isA(), + alloc->buffer()->isA(), "Cannot allocate global buffers that are not tensors."); - auto tv = alloc->buffer()->as(); - if (kernel->isOutput(tv)) { + auto tv = alloc->buffer()->as(); + if (tv->isFusionOutput()) { continue; } if (alloc->zeroInit()) { @@ -561,14 +555,14 @@ std::vector FusionExecutor::allocOutputs( kir::ExpressionEvaluator& expr_eval, const std::unordered_set& alias_indices) { FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs"); - const auto kernel = lowered_.kernel(); + const auto kernel = lowered_->kernel(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector outputs; for (const auto i : c10::irange(kernel->outputs().size())) { TORCH_INTERNAL_ASSERT( - kernel->outputs()[i]->isA(), + kernel->outputs()[i]->isA(), "Cannot allocate outputs that are not tensors."); - auto output = kernel->outputs()[i]->as(); + auto output = kernel->outputs()[i]->as(); if (alias_indices.count(i) == 0) { outputs.push_back( inferAndAllocOutput(output, expr_eval, options_, false)); @@ -581,7 +575,7 @@ std::vector FusionExecutor::allocOutputs( } void FusionExecutor::setUsedTVs() { - auto used_vals = fusion_.usedMathVals(); + auto used_vals = fusion_->usedMathVals(); auto used_tvs = ir_utils::filterByType(used_vals); used_tvs_.clear(); @@ -595,7 +589,7 @@ std::vector FusionExecutor::runFusion( const LaunchParams& launch_constraints, const c10::optional& opt_code) { FUSER_PERF_SCOPE("FusionExecutor::RunFusion"); - + TORCH_INTERNAL_ASSERT(compiled()); TORCH_INTERNAL_ASSERT( fusion_id_ > 0, "Cannot run fusion, it was not compiled."); TORCH_INTERNAL_ASSERT( @@ -607,11 +601,10 @@ std::vector FusionExecutor::runFusion( executor_entry = &executor_entry_lookup_[*opt_code]; } - FusionGuard fg(&fusion_); c10::DeviceGuard dg(options_.device); auto stream = at::cuda::getCurrentCUDAStream(); executor_utils::initializeCudaContext(); - + TORCH_INTERNAL_ASSERT(lowered_); LaunchParams launch_params; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector allocated_outputs = outputs; @@ -642,7 +635,7 @@ std::vector FusionExecutor::runFusion( } } else { TORCH_INTERNAL_ASSERT( - outputs.size() == fusion_.outputs().size(), + outputs.size() == fusion_->outputs().size(), __func__, " provided number of outputs does match fusion 
output"); } @@ -672,20 +665,35 @@ std::vector FusionExecutor::runFusion( // code path to take when either: // 1. no opt_code is provided or // 2. `executor_entry` is not initialized - executor_utils::validateKernelInputs(&fusion_, inputs, options_.device); + executor_utils::validateKernelInputs(fusion_, inputs, options_.device); if (!evaluator_precomputed_integers_) { evaluator_precomputed_integers_ = - std::make_unique(&fusion_, lowered_); + std::make_unique(lowered_->kernel()); } kir::ExpressionEvaluator expr_eval; - evaluator_precomputed_integers_->bindKernelInputs(inputs); + evaluator_precomputed_integers_->bindKernelInputs( + lowered_->kernel(), inputs); expr_eval.precomputedIntegers() = evaluator_precomputed_integers_.get(); launch_params = computeLaunchParams(launch_constraints, expr_eval, warp_size_); + // Recompile the kernel if the number of threads in the block has increased + if (launch_params.nThreads() > block_size_high_water_mark) { + const auto kernel = lowered_->kernel(); + const auto kernel_code = + codegen::generateCudaKernel(kernel, kernelName()); + const auto structured_code = getStructuredCode(kernel_code); + block_size_high_water_mark = launch_params.nThreads(); + compiled_kernel_ = executor_utils::nvrtcCompile( + structured_code, + (kernelNamespace() + "::" + kernelName()).c_str(), + fusion_id_, + block_size_high_water_mark); + } + if (kernel()->summary().has_cooperative_grid_reduction) { #ifndef __HIP_PLATFORM_HCC__ int num_blocks_per_SM = -1; @@ -716,16 +724,18 @@ std::vector FusionExecutor::runFusion( } executor_utils::validateVectorizedTensors( - &fusion_, inputs, outputs, lowered_, compileTimeDataCache(), expr_eval); - - auto& fusion = fusion_; + lowered_.get()->kernel(), + inputs, + outputs, + compileTimeDataCache(), + expr_eval); auto alias_indices_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::InputAliasIndices>( - compileTimeDataCache(), [&fusion]() { + compileTimeDataCache(), [&]() { return std::make_unique>>( - fusion.getInputAliasIndices()); + fusion_->getInputAliasIndices()); }); auto& alias_indices = alias_indices_entry.get(); @@ -736,9 +746,9 @@ std::vector FusionExecutor::runFusion( auto output_alias_indices_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::OutputAliasIndices>( - compileTimeDataCache(), [&fusion]() { + compileTimeDataCache(), [&]() { return std::make_unique>( - fusion.getOutputAliasIndices()); + fusion_->getOutputAliasIndices()); }); auto& output_alias_indices = output_alias_indices_entry.get(); @@ -753,7 +763,7 @@ std::vector FusionExecutor::runFusion( } else { // TODO: Update this as well; executor_utils::validateKernelOutputs( - &fusion_, allocated_outputs, options_.device); + fusion_, allocated_outputs, options_.device); } global_buffers = allocGlobalVals(expr_eval); @@ -802,7 +812,7 @@ std::vector FusionExecutor::runFusion( kernel_arguments.push(inputs); kernel_arguments.push(allocated_outputs); kernel_arguments.push(global_buffers.buffers); - if (lowered_.kernel()->summary().is_stochastic) { + if (lowered_->kernel()->summary().is_stochastic) { kernel_arguments.appendPhiloxRNGSeed(rand_offset); } } diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 523f2aa0e4b..40accbfb520 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -35,9 +35,9 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { void compileFusion( Fusion* fusion, - CompileOptions options = 
CompileOptions(), const at::ArrayRef& inputs = {}, - const LaunchParams& launch_constraints = LaunchParams()); + const LaunchParams& launch_constraints = LaunchParams(), + CompileOptions options = CompileOptions()); std::vector runFusion( const at::ArrayRef& inputs, @@ -55,7 +55,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { // function to query whether a `FusionExecutor` has a compiled kernel to // execute bool compiled() const { - return fusion_id_ != -1; + return fusion_id_ != -1 && lowered_; }; void evictCache(size_t cache_id) { @@ -85,7 +85,8 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { executor_utils::caching::ExecutorCompileTimeInfoCache; kir::Kernel* kernel() const { - return lowered_.kernel(); + TORCH_INTERNAL_ASSERT(lowered_); + return lowered_->kernel(); } //! Internal knob used for debugging/profiling only @@ -178,8 +179,6 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { } private: - Fusion fusion_; - CompileOptions options_; size_t max_device_smem = std::numeric_limits().max(); int warp_size_ = 0; @@ -192,7 +191,13 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { int fusion_id_ = -1; static int fusion_id_counter_; - GpuLower lowered_; + std::unique_ptr lowered_; + // Copy of lowered_->kernel() + Fusion* fusion_ = nullptr; + + // Track the block size this kernel was compiled with. If the block size + // increases, recompile to adjust maxregister count. + int64_t block_size_high_water_mark = 1; // lookup table to take short cut to retrieve recorded information in order to // launch kernels without re-inference parameters. diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp index 968570c1086..883fae207c5 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp @@ -1,4 +1,3 @@ -#include #include // Extract size and strides @@ -65,7 +64,7 @@ std::unique_ptr getTensorArg(int nDims) { false, "Tried to generate a tensor to run a generated kernel with ", nDims, - " dimensions, however it must be a size 0 to 8 dimensional tensor."); + " dimensions, however only 0 to 8 dimensional tensor are supported."); } return nullptr; } @@ -98,8 +97,6 @@ std::unique_ptr getTensorArg( } } -} // namespace - std::unique_ptr getTensorArg( c10::ScalarType dtype, int nDims, @@ -117,20 +114,73 @@ std::unique_ptr getTensorArg( return nullptr; } +} // namespace + // Push a tensor to the arguments void KernelArgumentHolder::push(const at::Tensor& tensor) { changed_ = true; - int nDims = tensor.ndimension(); - - c10::ScalarType dtype = tensor.scalar_type(); - std::unique_ptr tensor_arg = - getTensorArg(dtype, nDims, index_mode_); - tensor_arg->setPointer(tensor.data_ptr()); - for (const auto i : c10::irange(nDims)) { - tensor_arg->setSize(i, tensor.sizes()[i]); - tensor_arg->setStride(i, tensor.strides()[i]); + if (is_cpu_scalar(tensor)) { + switch (tensor.scalar_type()) { + case c10::ScalarType::Double: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Float: + arguments_.push_back( + std::make_unique>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Half: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::BFloat16: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Bool: 
+ arguments_.push_back( + std::make_unique>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Long: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Int: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + default: + TORCH_CHECK( + false, + "Dtype: ", + tensor.scalar_type(), + " not currently supported in code generated kernels."); + } + } else { + int nDims = tensor.ndimension(); + + c10::ScalarType dtype = tensor.scalar_type(); + std::unique_ptr tensor_arg = + getTensorArg(dtype, nDims, index_mode_); + tensor_arg->setPointer(tensor.data_ptr()); + for (const auto i : c10::irange(nDims)) { + tensor_arg->setSize(i, tensor.sizes()[i]); + tensor_arg->setStride(i, tensor.strides()[i]); + } + arguments_.push_back(std::move(tensor_arg)); } - arguments_.push_back(std::move(tensor_arg)); } // Push a scalar or integer to the arguments diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h index d306683c43d..d457a69adb2 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h @@ -33,6 +33,7 @@ struct TensorArgCodegen { } }; +// 0-Dim GPU based tensor template struct TensorArgCodegen { T& operator[](nvfuser_index_t ind) { @@ -51,6 +52,17 @@ struct TensorArgCodegen { } }; +// Specialization for 0-dim case that's easy to pass in a CPU based tensor +// without memcpy +template +struct CpuScalarTensorCodegen { + T& operator[](int) { + return data; + }; + + T data; +}; + struct ArgAbstract { virtual ~ArgAbstract() = default; virtual void* arg() = 0; @@ -67,7 +79,7 @@ struct PhiloxCudaStateArg : public ArgAbstract { struct LongArg : public ArgAbstract { int64_t val_; - explicit LongArg(int64_t _val) : val_(_val){}; + explicit LongArg(int64_t _val) : val_(_val) {} // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) void* arg() { return &val_; @@ -76,7 +88,7 @@ struct LongArg : public ArgAbstract { struct DoubleArg : public ArgAbstract { double val_; - explicit DoubleArg(double _val) : val_(_val){}; + explicit DoubleArg(double _val) : val_(_val) {} // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) void* arg() { return &val_; @@ -85,7 +97,7 @@ struct DoubleArg : public ArgAbstract { struct BoolArg : public ArgAbstract { bool val_; - explicit BoolArg(bool _val) : val_(_val){}; + explicit BoolArg(bool _val) : val_(_val) {} // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) void* arg() { return &val_; @@ -119,9 +131,20 @@ struct TensorArg : public TensorArgAbstract { } }; -std::unique_ptr getTensorArg( - c10::ScalarType dtype, - int nDims); +template +struct CpuScalarTensorArg : public ArgAbstract { + CPU_TENSOR_TYPE instance_; + + CpuScalarTensorArg() = delete; + + explicit CpuScalarTensorArg(decltype(CPU_TENSOR_TYPE::data) _data) { + instance_.data = _data; + } + + void* arg() override { + return &instance_; + } +}; class KernelArgumentHolder { public: diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 13cdc29099e..5323036e5df 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -110,6 +109,16 @@ bool validateKernelArgTensor( return false; } + if 
(is_cpu_scalar(arg) && !param->as()->isCpuScalar()) { + msg << "Argument is CPU Scalar Tensor, but parameter is not.\n"; + return false; + } + + if (!is_cpu_scalar(arg) && !arg.is_cuda()) { + msg << "Argumnet is a CPU tensor which is not supported in fusions.\n"; + return false; + } + // Check the rank of the tensors. size_t arg_dim = arg.dim(); // Note: This requires current Fusion to be active. @@ -126,7 +135,7 @@ bool validateKernelArgTensor( return false; } - if (arg.device() != device) { + if (!is_cpu_scalar(arg) && arg.device() != device) { msg << "Argument is on device that is not compiled for." << "\n"; return false; @@ -339,6 +348,8 @@ void validateKernelOutputs( !mismatch, "Found one or more invalid arguments: ", msg.str()); } +namespace { + bool canVectorize(const IValue& aten_val, int word_size) { if (!aten_val.isTensor()) { return false; @@ -371,16 +382,18 @@ bool canVectorize(const IValue& aten_val, int word_size) { return true; } +// Returns true if a TV can be used with ParallelType::Vectorize. When +// input or output tensors are involved, the other version of +// canVectorize is used. bool canVectorize( - TensorView* fusion_tv, + TensorView* tv, int word_size, - GpuLower& lower, kir::ExpressionEvaluator& expr_eval) { IterDomain* last_root_dim = nullptr; - // TODO: Should this be rfactor instead of root?? - for (size_t i = fusion_tv->getRootDomain().size(); i > 0; i--) { - auto r_id = fusion_tv->getRootDomain()[i - 1]; - if (r_id->isReduction() || r_id->isBroadcast()) { + for (size_t i = tv->getRootDomain().size(); i > 0; i--) { + auto r_id = tv->getRootDomain()[i - 1]; + if (r_id->isReduction() || r_id->isTrivialReduction() || + r_id->isBroadcast()) { continue; } last_root_dim = r_id; @@ -391,8 +404,7 @@ bool canVectorize( return false; } - auto last_dim_size = - expr_eval.evaluate(lower.lowerValue(last_root_dim->extent())); + auto last_dim_size = expr_eval.evaluate(last_root_dim->extent()); if (!last_dim_size.has_value()) { return false; @@ -405,8 +417,6 @@ bool canVectorize( return true; } -namespace { - // Check if there's any split that is non-divisible and vectorized. If // found, Vectorize is illegal. void validateVectorizedSplits( @@ -418,12 +428,12 @@ void validateVectorizedSplits( TORCH_INTERNAL_ASSERT( input_extent.has_value(), "Could not check if a split with vectorization is divisible because the extent, ", - kir::toString(extent_factor.first), + extent_factor.first->toString(), ", is not possible to evaluate."); TORCH_INTERNAL_ASSERT( input_extent.has_value(), "Could not check if a split with vectorization is divisible because the split factor, ", - kir::toString(extent_factor.second), + extent_factor.second->toString(), ", is not possible to evaluate."); TORCH_INTERNAL_ASSERT( input_extent.value() % split_factor.value() == 0, @@ -435,16 +445,144 @@ void validateVectorizedSplits( } } +//! Returns the position information of vectorized input/output tensors +//! in the given fusion. 
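// [Editorial sketch] The canVectorize(TensorView*, ...) overload above boils
// down to one rule: walk the root domain from the innermost axis outwards,
// skip reduction (including trivial-reduction) and broadcast axes, and the
// first remaining axis must have a known extent divisible by the word size.
// Below is a simplified, standalone illustration of that rule only; RootDim
// and can_vectorize are illustrative stand-ins, not nvFuser types.
#include <cstdint>
#include <optional>
#include <vector>

struct RootDim {
  std::optional<int64_t> extent; // nullopt if still symbolic / unevaluated
  bool is_reduction = false;     // covers trivial reductions as well
  bool is_broadcast = false;
};

inline bool can_vectorize(const std::vector<RootDim>& root_domain, int word_size) {
  for (auto it = root_domain.rbegin(); it != root_domain.rend(); ++it) {
    if (it->is_reduction || it->is_broadcast) {
      continue; // not a data-carrying innermost axis
    }
    // Innermost data-carrying axis: its extent must be known and divisible.
    return it->extent.has_value() && *it->extent % word_size == 0;
  }
  return false; // no eligible axis found
}
// The helper documented immediately above (getVectorizedTensorValidationInfo)
// continues below.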
+std::unique_ptr getVectorizedTensorValidationInfo( + Fusion* fusion) { + auto vectorized_tensor_info_ptr = + std::make_unique(); + auto& tv_to_vector_word_size = + vectorized_tensor_info_ptr->tv_to_vector_word_size; + auto& global_inp_misaligned_tv = + vectorized_tensor_info_ptr->global_inp_misaligned_tv; + auto& global_out_misaligned_tv = + vectorized_tensor_info_ptr->global_out_misaligned_tv; + + kir::ExpressionEvaluator expr_eval; + + // Find all vectorized tensors and their word size + for (auto expr : fusion->exprs()) { + if (!expr->isA() || + expr->as()->getUnaryOpType() != UnaryOpType::Set) { + continue; + } + auto uop = expr->as(); + if (!uop->out()->isA() || !uop->in()->isA()) { + continue; + } + auto out_tv = uop->out()->as(); + auto in_tv = uop->in()->as(); + IterDomain* vector_dim = nullptr; + for (auto id : out_tv->domain()->domain()) { + if (id->getParallelType() == ParallelType::Vectorize || + id->getParallelType() == ParallelType::MisalignedVectorize) { + TORCH_INTERNAL_ASSERT( + vector_dim == nullptr, + "Found multiple vectorized dimensions on tensor ", + out_tv); + vector_dim = id; + } + } + if (vector_dim == nullptr) { + continue; + } + auto vector_word_size = expr_eval.evaluate(vector_dim->extent()); + TORCH_INTERNAL_ASSERT( + vector_word_size.has_value(), + "Non constant vector dimension found in ", + out_tv); + + // The expression here must be a UnaryOp::Set, so checking either of the + // input or output tensor should be sufficient. When the output is a + // fusion output, check the tensor as its size information is available + // without using the expression evaluator. + auto tv_to_verify = out_tv->isFusionOutput() ? out_tv : in_tv; + tv_to_vector_word_size[tv_to_verify] = vector_word_size.value(); + + if (vector_dim->getParallelType() == ParallelType::MisalignedVectorize) { + TORCH_INTERNAL_ASSERT( + in_tv->isFusionInput() || out_tv->isFusionOutput(), + "MisalignedVectorize is assumed to be used with either input or output tensor"); + if (out_tv->getMemoryType() == MemoryType::Global && + in_tv->getMemoryType() == MemoryType::Local) { + global_out_misaligned_tv.insert(out_tv); + } else if ( + in_tv->getMemoryType() == MemoryType::Global && + out_tv->getMemoryType() == MemoryType::Local) { + global_inp_misaligned_tv.insert(in_tv); + } else { + TORCH_INTERNAL_ASSERT( + false, + "Unsupported memory configuration for misaligned vectorization."); + } + } + } + + // Check striding information on input and outputs as well as size information + // of all + auto& inp_misaligned_tensors_pos = + vectorized_tensor_info_ptr->inp_misaligned_tensors_pos; + auto& out_misaligned_tensors_pos = + vectorized_tensor_info_ptr->out_misaligned_tensors_pos; + auto& inp_pos_to_word_size_map_to_verify = + vectorized_tensor_info_ptr->inp_pos_to_word_size_map_to_verify; + auto& out_pos_to_word_size_map_to_verify = + vectorized_tensor_info_ptr->out_pos_to_word_size_map_to_verify; + auto& intermediate_tv_to_word_size_map_to_verify = + vectorized_tensor_info_ptr->intermediate_tv_to_word_size_map_to_verify; + + for (auto entry : tv_to_vector_word_size) { + auto tv = entry.first; + auto word_size = entry.second; + if (tv->isFusionInput()) { + auto inp_it = + std::find(fusion->inputs().begin(), fusion->inputs().end(), tv); + TORCH_INTERNAL_ASSERT( + inp_it != fusion->inputs().end(), + "Could not find ", + tv, + " in fusion inputs."); + auto inp_pos = std::distance(fusion->inputs().begin(), inp_it); + + if (global_inp_misaligned_tv.find(tv) != global_inp_misaligned_tv.end()) { + 
inp_misaligned_tensors_pos.emplace_back(inp_pos); + } else { + // Shouldn't visit same pos twice here, assert ? + inp_pos_to_word_size_map_to_verify[inp_pos] = word_size; + } + } else if (tv->isFusionOutput()) { + auto out_it = + std::find(fusion->outputs().begin(), fusion->outputs().end(), tv); + TORCH_INTERNAL_ASSERT( + out_it != fusion->outputs().end(), + "Could not find ", + tv, + " in provided fusion outputs."); + auto out_pos = std::distance(fusion->outputs().begin(), out_it); + + if (global_out_misaligned_tv.find(tv) != global_out_misaligned_tv.end()) { + out_misaligned_tensors_pos.emplace_back(out_pos); + } else { + out_pos_to_word_size_map_to_verify[out_pos] = word_size; + } + } else { + // Intermediate tensors. Note that this must be Vectorize as + // MisalignedVectorize is only supported for inputs and outputs. + intermediate_tv_to_word_size_map_to_verify[tv] = word_size; + } + } + + return vectorized_tensor_info_ptr; +} } // namespace // Misaligned vectorization check. Currently misaligned vectorization is limited // to global-register and register-global load/store patterns. However, this // could be improved to include shared memory. void validateVectorizedTensors( - Fusion* fusion, + kir::Kernel* kernel, const at::ArrayRef& inputs, const std::vector& outputs, - GpuLower& lower, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval) { FUSER_PERF_SCOPE("FusionExecutor::validateVectorizedTensors"); @@ -452,9 +590,8 @@ void validateVectorizedTensors( auto tensor_vectorization_validation_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::VectorizedTensorValidation>( - data_cache, [fusion, &lower]() { - return executor_utils::getVectorizedTensorValidationInfo( - fusion, lower); + data_cache, [kernel]() { + return executor_utils::getVectorizedTensorValidationInfo(kernel); }); // Validate all the canVectorizes: @@ -463,7 +600,7 @@ void validateVectorizedTensors( TORCH_INTERNAL_ASSERT( canVectorize(inputs[it.first], it.second), "Error vectorizing, ", - fusion->inputs()[it.first], + kernel->inputs()[it.first], " as input provided does not allowed vectorization by word size, ", it.second); } @@ -474,12 +611,24 @@ void validateVectorizedTensors( TORCH_INTERNAL_ASSERT( canVectorize(outputs[it.first], it.second), "Error vectorizing, ", - fusion->outputs()[it.first], + kernel->outputs()[it.first], " as output provided does not allowed vectorization by word size, ", it.second); } } + for (auto it : tensor_vectorization_validation_entry.get() + .intermediate_tv_to_word_size_map_to_verify) { + auto tv = it.first; + auto vec_width = it.second; + TORCH_INTERNAL_ASSERT( + canVectorize(tv, vec_width, expr_eval), + "Error vectorizing, ", + tv->toString(), + " as the extent of the vectorized axis does not allowed vectorization by word size, ", + vec_width); + } + std::vector inp_misaligned_tensors; std::vector out_misaligned_tensors; @@ -511,7 +660,7 @@ void validateVectorizedTensors( out_misaligned_tensors), "All global tensors must have the same stride for misaligned vectorization."); - validateVectorizedSplits(lower.kernel(), expr_eval); + validateVectorizedSplits(kernel, expr_eval); } kir::ExpressionEvaluator bindKernelInputs( @@ -530,7 +679,7 @@ kir::ExpressionEvaluator bindKernelInputs( for (const auto i : c10::irange(inputs.size())) { const auto input = inputs[i]; - if (auto tensor_input = dynamic_cast(input)) { + if (auto tensor_input = dynamic_cast(input)) { TORCH_INTERNAL_ASSERT( aten_inputs[i].isTensor(), "Something went 
wrong configuring launch. Inputs no longer match at index:", @@ -538,7 +687,7 @@ kir::ExpressionEvaluator bindKernelInputs( const auto aten_tensor = aten_inputs[i].toTensor(); const auto root_domain = - kir::TensorDomain::noReductions(tensor_input->domain()->rootDomain()); + TensorDomain::noReductions(tensor_input->domain()->getRootDomain()); TORCH_INTERNAL_ASSERT( aten_tensor.ndimension() == static_cast(root_domain.size()), "Something went wrong configuring launch. Inputs no longer match."); @@ -553,7 +702,7 @@ kir::ExpressionEvaluator bindKernelInputs( TORCH_CHECK( *prev_value == value, "Attempting to bind ", - kir::toString(extent), + extent->toString(), " to ", value, "but it's already set to ", @@ -561,7 +710,7 @@ kir::ExpressionEvaluator bindKernelInputs( should_bind = false; } } - if (should_bind && !extent->isConst()) { + if (should_bind && !extent->isConstScalar()) { expr_eval.bind(extent, value); } } @@ -697,24 +846,19 @@ NvrtcFunction nvrtcCompile( "--std=c++14", compute.c_str(), "-default-device"}; #endif - const char* disable_fastmath = getenv("PYTORCH_NVFUSER_DISABLE_FASTMATH"); - if (!disable_fastmath || (atoi(disable_fastmath) == 0)) { - args.push_back("--use_fast_math"); - } else { - TORCH_WARN_ONCE( - "fast math disabled in nvfuser, try set `PYTORCH_NVFUSER_DISABLE_FASTMATH=0`"); - } - const char* disable_fma = getenv("PYTORCH_NVFUSER_DISABLE_FMA"); - // int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0; - if (disable_fma && atoi(disable_fma)) { #ifdef __HIP_PLATFORM_HCC__ + if (disable_fma && atoi(disable_fma)) { TORCH_WARN_ONCE( "PYTORCH_CUDA_FUSER_DISABLE_FMA is not supported on ROCm, ignoring"); + } #else + if (disable_fma && atoi(disable_fma)) { args.push_back("--fmad=false"); -#endif + } else { + args.push_back("--fmad=true"); } +#endif #ifndef NDEBUG // Add line info to generated kernels @@ -1037,7 +1181,7 @@ template class ExecutorCompileTimeEntry; } // namespace caching std::vector getParallelBindingsIterDomains( - GpuLower& lower, + GpuLower* lower, const std::vector& used_tvs) { std::vector parallel_ids; for (auto tv : used_tvs) { @@ -1047,7 +1191,7 @@ std::vector getParallelBindingsIterDomains( // Want to keep the broadcast dimensions if they are not resolved // TODO: piping down the parallel dimension map here would // be helpful - auto& parallel_map = lower.caParallelMap(); + auto& parallel_map = lower->caParallelMap(); if (parallel_map.getConcreteMappedID(id) == id) { parallel_ids.push_back(id); } @@ -1062,39 +1206,41 @@ std::vector getParallelBindingsIterDomains( return parallel_ids; } +namespace { + void insertParallelExtent( - GpuLower& lower, IterDomain* binding_id, const std::unique_ptr& parallel_iter_extents_ptr) { - auto kir_extent = lower.lowerValue(binding_id->extent()); + auto extent = binding_id->extent(); const auto it = parallel_iter_extents_ptr->find(binding_id->getParallelType()); if (it != parallel_iter_extents_ptr->end()) { - it->second.push_back(kir_extent); + it->second.push_back(extent); } else { parallel_iter_extents_ptr->operator[](binding_id->getParallelType()) = { - kir_extent}; + extent}; } } +} // namespace + std::unique_ptr getParallelIterExtents( - GpuLower& lower, std::vector& parallel_binding_ids) { auto parallel_iter_extents_ptr = std::make_unique(); for (auto id : parallel_binding_ids) { - insertParallelExtent(lower, id, parallel_iter_extents_ptr); + insertParallelExtent(id, parallel_iter_extents_ptr); } return parallel_iter_extents_ptr; } std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower& lower, + 
GpuLower* lower, std::vector& parallel_binding_ids) { auto parallel_iter_extents_ptr = std::make_unique(); - auto& parallel_map = lower.caParallelMap(); + auto& parallel_map = lower->caParallelMap(); std::vector mapped; - bool is_tidx_warp_padded = lower.getWarpPaddedParallelInfo().is_tidx_padded; + bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded; for (auto id : parallel_binding_ids) { if (std::any_of( @@ -1109,7 +1255,7 @@ std::unique_ptr getSimplifiedParallelIterExtents( } insertParallelExtent( - lower, parallel_map.getConcreteMappedID(id), parallel_iter_extents_ptr); + parallel_map.getConcreteMappedID(id), parallel_iter_extents_ptr); mapped.push_back(id); } @@ -1117,7 +1263,7 @@ std::unique_ptr getSimplifiedParallelIterExtents( } std::unique_ptr getWarpPaddedExtentsInfo( - GpuLower& lower, + kir::Kernel* kernel, std::vector& parallel_binding_ids) { auto warp_padded_extent_info_ptr = std::make_unique(); @@ -1125,7 +1271,6 @@ std::unique_ptr getWarpPaddedExtentsInfo( warp_padded_extent_info_ptr->warp_padded_extent_set; auto& warp_padded_constant = warp_padded_extent_info_ptr->warp_padded_constant; - auto kernel = lower.kernel(); bool has_warp_reduction = kernel->getWarpPaddedParallelInfo().has_warp_reduction; @@ -1135,11 +1280,11 @@ std::unique_ptr getWarpPaddedExtentsInfo( if (has_warp_reduction) { if (id->hasPaddingToMultipleOfWarp() || kernel->isParallelTypePadded(id->getParallelType())) { - auto kir_extent = lower.lowerValue(id->extent()); - warp_padded_extent_set.insert(kir_extent); + auto extent = id->extent(); + warp_padded_extent_set.insert(extent); auto padded_value = id->getMaybeSizeAfterPadding(); if (padded_value.has_value()) { - warp_padded_constant[kir_extent] = padded_value.value(); + warp_padded_constant[extent] = padded_value.value(); } } } @@ -1147,122 +1292,6 @@ std::unique_ptr getWarpPaddedExtentsInfo( return warp_padded_extent_info_ptr; } -std::unique_ptr getVectorizedTensorValidationInfo( - Fusion* fusion, - GpuLower& lower) { - auto vectorized_tensor_info_ptr = - std::make_unique(); - auto& tv_to_vector_word_size = - vectorized_tensor_info_ptr->tv_to_vector_word_size; - auto& global_inp_misaligned_tv = - vectorized_tensor_info_ptr->global_inp_misaligned_tv; - auto& global_out_misaligned_tv = - vectorized_tensor_info_ptr->global_out_misaligned_tv; - - kir::ExpressionEvaluator expr_eval; - - // Find all vectorized tensors and their word size - for (auto expr : fusion->exprs()) { - if (!expr->isA() || - expr->as()->getUnaryOpType() != UnaryOpType::Set) { - continue; - } - auto uop = expr->as(); - if (!uop->out()->isA() || !uop->in()->isA()) { - continue; - } - auto out_tv = uop->out()->as(); - auto in_tv = uop->in()->as(); - IterDomain* vector_dim = nullptr; - for (auto id : out_tv->domain()->domain()) { - if (id->getParallelType() == ParallelType::Vectorize || - id->getParallelType() == ParallelType::MisalignedVectorize) { - TORCH_INTERNAL_ASSERT( - vector_dim == nullptr, - "Found multiple vectorized dimensions on tensor ", - out_tv); - vector_dim = id; - } - } - if (vector_dim == nullptr) { - continue; - } - auto vector_word_size = - expr_eval.evaluate(lower.lowerValue(vector_dim->extent())); - TORCH_INTERNAL_ASSERT( - vector_word_size.has_value(), - "Non constant vector dimension found in ", - out_tv); - tv_to_vector_word_size[out_tv] = vector_word_size.value(); - tv_to_vector_word_size[in_tv] = vector_word_size.value(); - - if (vector_dim->getParallelType() == ParallelType::MisalignedVectorize) { - if (out_tv->getMemoryType() == 
MemoryType::Global && - in_tv->getMemoryType() == MemoryType::Local) { - global_out_misaligned_tv.insert(out_tv); - } else if ( - in_tv->getMemoryType() == MemoryType::Global && - out_tv->getMemoryType() == MemoryType::Local) { - global_inp_misaligned_tv.insert(in_tv); - } else { - TORCH_INTERNAL_ASSERT( - false, - "Unsupported memory configuration for misaligned vectorization."); - } - } - } - - // Check striding information on input and outputs as well as size information - // of all - auto& inp_misaligned_tensors_pos = - vectorized_tensor_info_ptr->inp_misaligned_tensors_pos; - auto& out_misaligned_tensors_pos = - vectorized_tensor_info_ptr->out_misaligned_tensors_pos; - auto& inp_pos_to_word_size_map_to_verify = - vectorized_tensor_info_ptr->inp_pos_to_word_size_map_to_verify; - auto& out_pos_to_word_size_map_to_verify = - vectorized_tensor_info_ptr->out_pos_to_word_size_map_to_verify; - - for (auto entry : tv_to_vector_word_size) { - auto tv = entry.first; - auto word_size = entry.second; - if (tv->isFusionInput()) { - auto inp_it = - std::find(fusion->inputs().begin(), fusion->inputs().end(), tv); - TORCH_INTERNAL_ASSERT( - inp_it != fusion->inputs().end(), - "Could not find ", - tv, - " in fusion inputs."); - auto inp_pos = std::distance(fusion->inputs().begin(), inp_it); - - if (global_inp_misaligned_tv.find(tv) != global_inp_misaligned_tv.end()) { - inp_misaligned_tensors_pos.emplace_back(inp_pos); - } else { - // Shouldn't visit same pos twice here, assert ? - inp_pos_to_word_size_map_to_verify[inp_pos] = word_size; - } - } else if (tv->isFusionOutput()) { - auto out_it = - std::find(fusion->outputs().begin(), fusion->outputs().end(), tv); - TORCH_INTERNAL_ASSERT( - out_it != fusion->outputs().end(), - "Could not find ", - tv, - " in provided fusion outputs."); - auto out_pos = std::distance(fusion->outputs().begin(), out_it); - - if (global_out_misaligned_tv.find(tv) != global_out_misaligned_tv.end()) { - out_misaligned_tensors_pos.emplace_back(out_pos); - } else { - out_pos_to_word_size_map_to_verify[out_pos] = word_size; - } - } - } - - return vectorized_tensor_info_ptr; -} - } // namespace executor_utils } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.h b/torch/csrc/jit/codegen/cuda/executor_utils.h index d851be48991..93deec6343f 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.h +++ b/torch/csrc/jit/codegen/cuda/executor_utils.h @@ -28,28 +28,16 @@ namespace executor_utils { // Include all the functions we might need in generated code std::string kernelPreamble(); -// TODO(kir): rewrite in terms of Kernel inputs void validateKernelInputs( Fusion* fusion, const at::ArrayRef& inputs, const c10::Device& device); -// TODO(kir): rewrite in terms of Kernel outputs void validateKernelOutputs( Fusion* fusion, const std::vector& outputs, const c10::Device& device); -// Returns if vectorizing the aten value by word size is possible -bool canVectorize(const IValue& aten_val, int word_size); - -// Returns if vectorizing the aten value by word size is possible -bool canVectorize( - TensorView* fusion_tv, - int word_size, - GpuLower& lower, - kir::ExpressionEvaluator& expr_eval); - //! 
Bind kernel input values to runtime values kir::ExpressionEvaluator bindKernelInputs( const at::ArrayRef& aten_inputs, @@ -112,7 +100,7 @@ class ParallelBindingIterDomains { class ParallelIterExtentMap { public: using DataType = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; static const CompileTimeEntryType EntryType = CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; }; @@ -133,7 +121,7 @@ class ParallelIterExtentMap { class SimplifiedParallelIterExtentMap { public: using DataType = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; static const CompileTimeEntryType EntryType = CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP; }; @@ -141,8 +129,8 @@ class SimplifiedParallelIterExtentMap { //! WarpPaddedExtentsInfo: //! Auxiliary data type for entry class WarpPaddedParallelExtents struct WarpPaddedExtentsInfo { - std::unordered_set warp_padded_extent_set; - std::unordered_map warp_padded_constant; + std::unordered_set warp_padded_extent_set; + std::unordered_map warp_padded_constant; }; //! Compile-time info to be cached in each FusionExecutor: @@ -166,6 +154,8 @@ struct VectorizedTensorInfo { std::vector out_misaligned_tensors_pos; std::unordered_map inp_pos_to_word_size_map_to_verify; std::unordered_map out_pos_to_word_size_map_to_verify; + std::unordered_map + intermediate_tv_to_word_size_map_to_verify; }; //! Compile-time info to be cached in each FusionExecutor: @@ -284,42 +274,33 @@ class ExecutorCompileTimeEntry { //! Returns the vector of tensorviews that will be used to bind parallel //! dimensions. std::vector getParallelBindingsIterDomains( - GpuLower& lower, + GpuLower* lower, const std::vector& used_tvs); using ParallelExtentMap = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; //! Returns the extents of all parallel binding iterdomains corresponding //! to each parallel type. std::unique_ptr getParallelIterExtents( - GpuLower& lower, std::vector& parallel_binding_ids); //! Returns the simplified set of extents necessary for launch parameter //! binding. std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower& lower, + GpuLower* lower, std::vector& parallel_binding_ids); //! Returns the symbolic or constant extents of warp padded parallel //! iterdomains in the given vector. std::unique_ptr getWarpPaddedExtentsInfo( - GpuLower& lower, + kir::Kernel* kernel, std::vector& parallel_binding_ids); -//! Returns the position information of vectorized input/output tensors -//! in the given fusion. 
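// [Editorial sketch] ParallelIterExtentMap above caches, per ParallelType,
// every extent that binds that parallel dimension; launch parameters are then
// derived by evaluating entries from each bucket. A minimal standalone model
// of that grouping step (what getParallelIterExtents does) follows; PType,
// BindingId and the std::string extents are illustrative stand-ins, not the
// nvFuser types.
#include <string>
#include <unordered_map>
#include <vector>

enum class PType { TIDx, TIDy, TIDz, BIDx, BIDy, BIDz };

struct BindingId {
  PType ptype;        // parallel dimension this iterdomain binds
  std::string extent; // stand-in for the symbolic extent Val*
};

using ExtentMap = std::unordered_map<PType, std::vector<std::string>>;

inline ExtentMap group_parallel_extents(const std::vector<BindingId>& ids) {
  ExtentMap map;
  for (const auto& id : ids) {
    map[id.ptype].push_back(id.extent); // bucket created on first use
  }
  return map;
}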
-std::unique_ptr getVectorizedTensorValidationInfo( - Fusion* fusion, - GpuLower& lower); - -// TODO(kir): rewrite in terms of Kernel tensors void validateVectorizedTensors( - Fusion* fusion, + kir::Kernel* kernel, const at::ArrayRef& inputs, const std::vector& outputs, - GpuLower& lower, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval); diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index ced4b59a783..5630743b6f6 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index d9d71e53c41..be686c0d943 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -8,10 +8,9 @@ #include #include #include +#include #include -#include - namespace torch { namespace jit { namespace fuser { @@ -37,13 +36,7 @@ void swap(Fusion& a, Fusion& b) noexcept { using std::swap; - // Swap the content - swap(a.val_set_, b.val_set_); - swap(a.expr_set_, b.expr_set_); - swap(a.val_deque_, b.val_deque_); - - swap(a.val_type_name_map_, b.val_type_name_map_); - swap(a.expr_name_counter_, b.expr_name_counter_); + swap(static_cast(a), static_cast(b)); swap(a.inputs_, b.inputs_); swap(a.outputs_, b.outputs_); @@ -51,27 +44,6 @@ void swap(Fusion& a, Fusion& b) noexcept { swap(a.io_alias_, b.io_alias_); swap(a.permuted_input_map_, b.permuted_input_map_); swap(a.permuted_output_map_, b.permuted_output_map_); - - // Fixup the Statement::fusion_ links for a - for (auto val : a.val_set_) { - val->fusion_ = &a; - } - for (auto expr : a.expr_set_) { - expr->fusion_ = &a; - } - - // Fixup the Statement::fusion_ links for b - for (auto val : b.val_set_) { - val->fusion_ = &b; - } - for (auto expr : b.expr_set_) { - expr->fusion_ = &b; - } -} - -Fusion::Fusion(const Fusion& other) { - FUSER_PERF_SCOPE("Fusion copy"); - Fusion::copy(&other, this); } std::unique_ptr Fusion::segment( @@ -82,28 +54,13 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - IrCloner ir_cloner(to); + auto ir_cloner = IrContainer::copy(from, to); - for (auto val : from->val_set_) { - to->val_set_.insert(ir_cloner.clone(val)); - } - - for (auto expr : from->expr_set_) { - to->expr_set_.insert(ir_cloner.clone(expr)); - } - - for (auto val : from->val_deque_) { - to->val_deque_.push_back(ir_cloner.clone(val)); - } - - for (auto val : from->val_set_) { + for (auto val : from->vals_) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); } - to->val_type_name_map_ = from->val_type_name_map_; - to->expr_name_counter_ = from->expr_name_counter_; - to->inputs_ = ir_cloner.clone(from->inputs_); to->outputs_ = ir_cloner.clone(from->outputs_); @@ -117,9 +74,22 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->permuted_input_map_ = from->permuted_input_map_; to->permuted_output_map_ = from->permuted_output_map_; + to->all_tv_uses_valid_ = from->all_tv_uses_valid_; + // This should never be true on copy, but copying for completeness. + to->is_during_update_uses_ = from->is_during_update_uses_; + return ir_cloner; } +// Clang tidy complains when using default constructor for IrContainer instead +// of copy constructor. 
Fusion::copy has a call to IrContainer::copy, so it's +// redundant to use the IrContainer copy constructor, but it is harmless since +// Fusion::copy starts by calling clear(). +Fusion::Fusion(const Fusion& other) : IrContainer(other) { + FUSER_PERF_SCOPE("Fusion copy"); + Fusion::copy(&other, this); +} + Fusion::Fusion(Fusion&& other) noexcept { FUSER_PERF_SCOPE("Fusion move"); swap(*this, other); @@ -147,36 +117,22 @@ Fusion::~Fusion() { void Fusion::clear() noexcept { FUSER_PERF_SCOPE("Fusion clear"); - // Free the owned values - for (auto ptr : val_set_) { - delete ptr; - } - - // Free the owned expressions - for (auto ptr : expr_set_) { - delete ptr; - } - - val_set_.clear(); - val_deque_.clear(); - expr_set_.clear(); - - for (auto& kv : val_type_name_map_) { - kv.second = 0; - } - - expr_name_counter_ = 0; + IrContainer::clear(); inputs_.clear(); outputs_.clear(); io_alias_.clear(); + permuted_input_map_.clear(); permuted_output_map_.clear(); + + all_tv_uses_valid_ = false; + is_during_update_uses_ = false; } void Fusion::removeExpr(Expr* expr) { - assertInFusion(expr, "Cannot remove expr "); + assertInContainer(expr, "Cannot remove expr "); // If we hit this error too frequently, we could lighten the restrictions so // that removing something that doesn't exist simply does nothing. For now, // we're going with the strictest model which errors. @@ -194,13 +150,11 @@ void Fusion::removeExpr(Expr* expr) { } } - expr_set_.erase(expr); - - delete expr; + IrContainer::removeExpr(expr); } void Fusion::removeVal(Val* val) { - assertInFusion(val, "Cannot remove val "); + assertInContainer(val, "Cannot remove val "); TORCH_CHECK( !val->isFusionInput(), @@ -213,22 +167,14 @@ void Fusion::removeVal(Val* val) { if (orig != nullptr) removeExpr(val->definition()); - for (Expr* use : unordered_uses(val)) + for (Expr* use : unordered_uses(val)) { removeExpr(use); - - val_set_.erase(val); - - for (auto it = val_deque_.begin(); it != val_deque_.end(); it++) - if (*it == val) { - val_deque_.erase(it); - break; - } - - delete val; + } + IrContainer::removeVal(val); } void Fusion::addInput(Val* input) { - assertInFusion(input, "Cannot register input "); + assertInContainer(input, "Cannot register input "); if (input->getValType().value() == ValType::TensorView) { auto tv = input->as(); @@ -242,7 +188,7 @@ void Fusion::addInput(Val* input) { } void Fusion::addOutput(Val* output) { - assertInFusion(output, "Cannot register output "); + assertInContainer(output, "Cannot register output "); if (output->getValType().value() == ValType::TensorView) { auto tv = output->as(); tv->setMemoryType(MemoryType::Global); @@ -307,27 +253,8 @@ void Fusion::replaceOutput(Val* output, Val* replacement) { } } -bool Fusion::inFusion(const Statement* stmt) const { - bool in_fusion = stmt->fusion() == this; - Statement* nonconst_stmt = const_cast(stmt); // NOLINT - - if (stmt->isExpr()) { - in_fusion &= expr_set_.find(nonconst_stmt->as()) != expr_set_.end(); - } - if (stmt->isVal()) { - in_fusion &= val_set_.find(nonconst_stmt->as()) != val_set_.end(); - } - - return in_fusion; -} - -void Fusion::assertInFusion(const Statement* stmt, const std::string& msg) - const { - TORCH_CHECK(inFusion(stmt), msg, " it was not found in the active fusion."); -} - std::vector Fusion::exprs() { - return ExprSort::getExprs(this); + return StmtSort::getExprs(this); } std::vector Fusion::inputsOf(Val* val) { @@ -341,12 +268,24 @@ void Fusion::validateInputs() { all_inputs.insert(input); } } + + std::unordered_set input_dims; + auto inp_tvs = 
ir_utils::filterByType(inputs()); + for (auto tv : inp_tvs) { + for (auto id : tv->getMaybeRFactorDomain()) { + input_dims.emplace(id->extent()); + } + } for (Val* input : all_inputs) { if (!input->isConstScalar()) { TORCH_CHECK( - hasInput(input) || inFusion(input), + input->isFusionInput() || + // TODO: Switch: + inContainer(input), + // to: input_dims.find(input) != input_dims.end(), + // https://github.com/csarofeen/pytorch/issues/1365 "Could not figure out how ", - input, + input->toString(), " is generated, however it was not specified as an input."); } } @@ -367,6 +306,10 @@ void Fusion::print() { void Fusion::printKernel() { FUSER_PERF_SCOPE("Fusion::printKernel"); + TORCH_INTERNAL_ASSERT( + !this->isA(), + "Cannot \"print kernel\" of a kernel container. ", + "This would require lowering during lowering."); std::cout << codegen::generateCudaKernel(GpuLower(this).kernel()); } @@ -394,7 +337,7 @@ void Fusion::printMath(bool from_outputs_only) { leaf_vals.push_back(val); } } - exprs_for_print = ExprSort::getExprs(this, leaf_vals); + exprs_for_print = StmtSort::getExprs(this, leaf_vals); } std::cout << "\n%kernel_math {\n"; @@ -412,33 +355,36 @@ void Fusion::printTransforms() { t_exprs.handle(this); } -StmtNameType Fusion::registerVal(Val* val) { +void Fusion::registerVal(Val* val) { + if (inContainer(val)) { + return; + } + if (val->fusion()) { - if (val->fusion() != this) { - TORCH_CHECK(false, val, " was not found in the active fusion."); - } - if (inFusion(val)) { - return val->name(); - } + TORCH_CHECK( + val->fusion() == this, val, " was not found in the active fusion."); } - val_set_.emplace(val); - val_deque_.push_back(val); - return getValName(*(val->getValType())); + IrContainer::registerVal(val); } -StmtNameType Fusion::registerExpr(Expr* expr) { +void Fusion::registerExpr(Expr* expr) { + if (inContainer(expr)) { + return; + } + if (expr->fusion()) { - if (expr->fusion() != this) { - TORCH_CHECK(false, expr, " was not found in the active fusion."); - } - if (inFusion(expr)) { - return expr->name(); - } + TORCH_CHECK( + expr->fusion() == this, expr, " was not found in the active fusion."); } + IrContainer::registerExpr(expr); + + bool has_tv = false; + for (Val* input : expr->inputs()) { - assertInFusion(input, "Input to expr is invalid, "); + has_tv = has_tv || input->isA(); + assertInContainer(input, "Input to expr is invalid, "); auto uses_copy = input->uses(); if (std::find(uses_copy.begin(), uses_copy.end(), expr) == uses_copy.end()) { @@ -447,34 +393,25 @@ StmtNameType Fusion::registerExpr(Expr* expr) { } } + // Kernel is the only container type that is non-ssa. This is mainly (maybe + // only) because of initialization expressions which would overwrite tensor + // view definitions. 
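// [Editorial sketch] The comment above is the crux of the registerExpr change:
// SSA containers let a new expression displace an existing definition, while
// the non-SSA kernel container keeps the first definition so that buffer
// initialization expressions never overwrite a tensor's defining op. A
// minimal standalone model of that rule, with std::string standing in for
// Val*/Expr*:
#include <string>
#include <unordered_map>

struct DefinitionTable {
  bool is_ssa = true;
  std::unordered_map<std::string, std::string> definition; // value -> defining expr

  void register_definition(const std::string& value, const std::string& expr) {
    auto it = definition.find(value);
    if (it == definition.end()) {
      definition.emplace(value, expr); // first definition is always recorded
    } else if (is_ssa) {
      it->second = expr; // SSA: the new expr supersedes (old def would be removed)
    }
    // non-SSA: keep the original definition; the new expr is just another writer
  }
};
// The actual registration logic the comment above introduces continues below.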
+ bool is_ssa = !this->isA(); + for (Val* output : expr->outputs()) { - assertInFusion(output, "Output to expr is invalid, "); - if (output->definition() != nullptr) { + has_tv = has_tv || output->isA(); + assertInContainer(output, "Output to expr is invalid, "); + if (output->definition() != nullptr && is_ssa) { removeExpr(output->definition()); } - output->setDefinition(expr); + if (is_ssa || (!is_ssa && output->definition() == nullptr)) { + output->setDefinition(expr); + } } - expr_set_.emplace(expr); - - resetTvUses(); - return getExprName(); -} - -StmtNameType Fusion::registerStatement(Statement* stmt) { - if (inFusion(stmt)) - return stmt->name(); - - if (stmt->isVal()) { - return registerVal(stmt->as()); - } else if (stmt->isExpr()) { - return registerExpr(stmt->as()); + if (has_tv) { + resetTvUses(); } - - TORCH_INTERNAL_ASSERT( - false, - "Could not register statement as Fusion could not recognize its type."); - return kInvalidStmName; } void Fusion::resetTvUses() { @@ -484,8 +421,8 @@ void Fusion::resetTvUses() { // getExprs only uses definition, so even if we've modified uses already to // remove dead exprs, this could reinsert them. getExprs is also boundeds by // inputs as registered inputs will return nullptr as their definition. - const auto all_tvs = ir_utils::filterByType(val_set_); - const auto used_exprs = ExprSort::getExprs(this); + const auto all_tvs = ir_utils::filterByType(vals_); + const auto used_exprs = StmtSort::getExprs(this); for (auto tv : all_tvs) { tv->setUses({}); @@ -507,14 +444,6 @@ void Fusion::resetTvUses() { is_during_update_uses_ = false; } -const std::unordered_set& Fusion::vals() const noexcept { - return val_set_; -} - -const std::deque& Fusion::deterministic_vals() const noexcept { - return val_deque_; -} - std::vector Fusion::usedMathVals() { // Note that using fusion->inputs() as the argument for the first // parameter of getAllValsBetween does not grab all used vals as @@ -553,37 +482,15 @@ std::vector Fusion::usedMathVals() { return used_math_vals; } -const std::unordered_set& Fusion::unordered_exprs() const noexcept { - return expr_set_; -} - std::unordered_set Fusion::unordered_uses(Val* val) const { return std::unordered_set(val->uses().begin(), val->uses().end()); } Expr* Fusion::definition(const Val* val) const { - assertInFusion(val, "Cannot detect the definition of val, "); + assertInContainer(val, "Cannot detect the definition of val, "); return val->definition(); } -bool Fusion::hasInput(const Val* val) const { - assertInFusion(val, "Cannot check if val is an input, "); - return val->isFusionInput(); -} - -bool Fusion::hasOutput(const Val* val) const { - assertInFusion(val, "Cannot check if val is an output, "); - return val->isFusionOutput(); -} - -StmtNameType Fusion::getValName(ValType vtype) { - return val_type_name_map_[vtype]++; -} - -StmtNameType Fusion::getExprName() { - return expr_name_counter_++; -} - // Indicate to kernel to set itself up to generate random numbers bool Fusion::isStochastic() { for (auto expr : exprs()) @@ -593,28 +500,6 @@ bool Fusion::isStochastic() { return false; } -bool Fusion::hasReduction() { - FUSER_PERF_SCOPE("Fusion::hasReduction"); - - for (auto expr : exprs()) - for (auto out : expr->outputs()) - if (out->getValType() == ValType::TensorView) - if (out->as()->hasReduction()) - return true; - - return false; -} - -bool Fusion::hasWelford() { - FUSER_PERF_SCOPE("Fusion::hasWelford"); - for (auto expr : exprs()) { - if (expr->isA()) { - return true; - } - } - return false; -} - std::vector 
Fusion::getTerminatingOutputs() { FUSER_PERF_SCOPE("getTerminatingOutputs"); diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index c892bd8171c..2e76e00896b 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -1,10 +1,11 @@ #pragma once #include +#include #include -#include #include +#include #include #include @@ -69,14 +70,14 @@ class TORCH_CUDA_CU_API FusionGuard { //! Fusion is mutable but unique. Nodes cannot be copied in any way from one //! Fusion to another. If anything like that is desired, it would require -//! duplicating all associated values and exprs. Fusion is considered to SSA, +//! duplicating all associated values and exprs. Fusion is considered to be SSA, //! though this could also change in the future if there is a good reason to do //! so. //! //! The Fusion owns the whole IR graph (Vals and Exprs) //! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API Fusion final { +class TORCH_CUDA_CU_API Fusion : public IrContainer { typedef std::unordered_map> PermutationMap; public: @@ -96,45 +97,30 @@ class TORCH_CUDA_CU_API Fusion final { //! Break dependency chains associated with Expr, remove references to expr //! delete expr - void removeExpr(Expr* expr); + void removeExpr(Expr* expr) override; //! Completely remove val from the fusion, break all dependencies associated //! with it - void removeVal(Val* val); + void removeVal(Val* val) override; //! Register input as an input of the fusion - // TODO: Rename to register void addInput(Val* input); //! Register output as an output of the fusion - // TODO: Rename to register void addOutput(Val* output); //! Register output as an output of the fusion - // TODO: Rename to register void addOutput(WelfordResult& output); //! Deregister input as an input of the fusion - // TODO: Rename to register void removeInput(Val* input); //! Deregister output as an output of the fusion - // TODO: Rename to register void removeOutput(Val* output); //! Replace output with another value void replaceOutput(Val* output, Val* replacement); - //! Clear Expr's from TV uses that are not required to produce outputs from - //! inputs - void resetTvUses(); - - //! Check if stmt is properly registered with this fusion - bool inFusion(const Statement* stmt) const; - - //! Throw an error if stmt is not in this fusion - void assertInFusion(const Statement* stmt, const std::string& msg = "") const; - //! Assert that all leaves found from outputs are registered as an input void validateInputs(); @@ -151,17 +137,6 @@ class TORCH_CUDA_CU_API Fusion final { //! Lower the fusion and print a kernel void printKernel(); - //! Register the Val with this fusion - StmtNameType registerVal(Val* val); - - //! Register expr with this fusion. - //! When we register an expression, we want to update the dependency tracking - //! of Vals. We add expr to our general expr_set_, - StmtNameType registerExpr(Expr* expr); - - //! Register stmt with this fusion - StmtNameType registerStatement(Statement* stmt); - //! Return a list of topologically sorted expressions. This only includes //! exprs required to genereate registered outputs. std::vector exprs(); @@ -169,12 +144,6 @@ class TORCH_CUDA_CU_API Fusion final { //! Return a vector of fusion inputs that feed this Val std::vector inputsOf(Val* val); - //! Return the set of Vals registered with this fusion - const std::unordered_set& vals() const noexcept; - - //! 
Return in insertion order - const std::deque& deterministic_vals() const noexcept; - //! Return all Vals in math expressions that cannot be eliminated. //! //! It is generally equivalent to vals that are used to generate @@ -183,11 +152,6 @@ class TORCH_CUDA_CU_API Fusion final { //! also included as they must show up in the final code. std::vector usedMathVals(); - //! Return the set of Exprs registered with this fusion. Warning: This will - //! return exprs outside inputs/outputs, so can be unsafe for use with - //! segmented fusions. - const std::unordered_set& unordered_exprs() const noexcept; - //! Return all Exprs that use val std::unordered_set unordered_uses(Val* val) const; @@ -197,12 +161,6 @@ class TORCH_CUDA_CU_API Fusion final { //! Indicate to kernel to set itself up to generate random numbers bool isStochastic(); - //! Indicate that the fusion contains reduction operations - bool hasReduction(); - - //! Indicate that the fusion contains welford operations - bool hasWelford(); - //! Run fusion segmentation algorithm to create a segmented fusion std::unique_ptr segment( const at::ArrayRef& inputs); @@ -217,9 +175,6 @@ class TORCH_CUDA_CU_API Fusion final { std::vector getTerminatingOutputs(); - bool hasInput(const Val* val) const; - bool hasOutput(const Val* val) const; - // Aliasing output to input value, this is a WAR to allow inplace update on // input tensor. // Note: this is not always safe and should be used with extra caution. @@ -262,36 +217,40 @@ class TORCH_CUDA_CU_API Fusion final { return is_during_update_uses_; } + const auto& ioAlias() const { + return io_alias_; + } + protected: friend SegmentCandidateFinder; friend SegmentedFusion; friend class TranslateApplicableWelford; + friend Val; static IrCloner copy(const Fusion* from, Fusion* to); - private: - // Return an int that monotonically increases for each val/expr, some are - // explicitly incremented by type. - StmtNameType getValName(ValType vtype); - StmtNameType getExprName(); + //! Register the Val with this fusion + virtual void registerVal(Val* val) override; + + //! Register expr with this fusion. + //! When we register an expression, we want to update the dependency tracking + //! of Vals. If this container is a not a Kernel, it will remove previous + //! definitions of outputs and register this Expr as the definition. Otherwise + //! will update definition if not previously set, but will not remove old + //! definitions. + virtual void registerExpr(Expr* expr) override; + //! Clear Expr's from TV uses that are not required to produce outputs from + //! inputs. Only other place this is used (other than Fusion) is in + //! 
Val::uses() + void resetTvUses(); + + private: // Determine if the two values are compatible for aliasing // Same DataType, ValType, and number of dimensions bool isAliasCompatible(Val* left, Val* right); private: - // Sets of all Vals/Exprs registered with this fusion - // (val_deque_ is not owning the objects) - std::unordered_set val_set_; - std::deque val_deque_; - std::unordered_set expr_set_; - - // Values names counters - std::unordered_map val_type_name_map_; - - // Expression names counter - StmtNameType expr_name_counter_ = 0; - // Fusion inputs and outputs std::vector inputs_; std::vector outputs_; diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp index 9ff25780814..fd7b6fc502a 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp @@ -322,7 +322,7 @@ void SegmentedFusion::draw() { for (auto group : groups()) { for (auto expr : group->exprs()) { - if (ir_utils::isTVOp(expr)) { + if (ir_utils::isTvOp(expr)) { expr_color_map[expr] = group_index; } } @@ -659,8 +659,8 @@ TensorView* castIntermediateValueInCompleteFusion( } // Create the actual domain and tv. - return new TensorView( - new TensorDomain( + return IrBuilder::create( + IrBuilder::create( new_root_domain, std::vector(new_root_domain.size(), true)), data_type); }; @@ -680,8 +680,8 @@ TensorView* castIntermediateValueInCompleteFusion( } // Insert the cast ops. - new UnaryOp(UnaryOpType::Cast, half_precision_tv, original_tv); - new UnaryOp(UnaryOpType::Cast, fp32_tv, half_precision_tv); + IrBuilder::create(UnaryOpType::Cast, half_precision_tv, original_tv); + IrBuilder::create(UnaryOpType::Cast, fp32_tv, half_precision_tv); // Return the new tv to replace original tv with // on the segmented edges. @@ -1740,9 +1740,10 @@ TranslateApplicableWelford::TranslateApplicableWelford( Fusion* fusion, const at::ArrayRef& runtime_inputs) : runtime_inputs_(runtime_inputs) { + auto exprs = fusion->exprs(); std::vector orignal_welfords( - ir_utils::filterByType(fusion->unordered_exprs()).begin(), - ir_utils::filterByType(fusion->unordered_exprs()).end()); + ir_utils::filterByType(exprs).begin(), + ir_utils::filterByType(exprs).end()); if (wouldTranslateToPersistent(orignal_welfords)) { for (auto welford : orignal_welfords) { @@ -1829,6 +1830,14 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( [&original_to_test_map](auto welford) { return original_to_test_map.clone(welford); }); + // Copied welfords will be invalidated on translation, but Vals will be + // reused, keep a reference to them. + std::vector welford_avgs; + std::vector welford_vars; + for (auto welford : copied_welfords) { + welford_avgs.push_back(welford->outAvg()); + welford_vars.push_back(welford->outVar()); + } // Translate the welford ops for (auto welford_to_translate : copied_welfords) { @@ -1860,6 +1869,21 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( return original_to_test_map.clone(out); }); + // If only average is used from welford, we should still translate, but we + // might not detect persistence if variance isn't actually used/marked as an + // output in the test. 
+ for (auto outs_i : c10::irange(welford_avgs.size())) { + auto avg = welford_avgs[outs_i]; + auto var = welford_vars[outs_i]; + if (avg->uses().empty()) { + test_group_outputs_.push_back(avg); + } + + if (var->uses().empty()) { + test_group_outputs_.push_back(var); + } + } + // Temporarily localize test copy around // the group boundary FusionSegmentGuard fsg( @@ -1900,7 +1924,7 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { // Create scalar version of the feature element // counting. - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(1); std::vector broadcast_mask(in_root.size(), false); for (const auto i : c10::irange(in_root.size())) { if (out_root[i]->isReduction()) { @@ -1913,7 +1937,7 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { // Build a normalization expression group that is // equivalent to a welford operation. auto x_sum = sum(in_val, red_axes); - new BinaryOp(BinaryOpType::Div, out_avg, x_sum, num_features); + IrBuilder::create(BinaryOpType::Div, out_avg, x_sum, num_features); // welford.avg may be broadcast. Reuse it if found. TensorView* x_avg_bcast = nullptr; for (auto& use_expr : out_avg->uses()) { @@ -1949,8 +1973,12 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { } auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); - new ReductionOp(BinaryOpType::Add, new Double(0.0), out_var, x_mean_sub_pow); - new UnaryOp(UnaryOpType::Set, out_N, num_features); + IrBuilder::create( + BinaryOpType::Add, + IrBuilder::create(0.0), + out_var, + x_mean_sub_pow); + IrBuilder::create(UnaryOpType::Set, out_N, num_features); // out_avg, out_N are now outputs of a pointwise ops and we // need to clear out its reduction domains. 
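As a reference for the equivalence translateSingleWelford relies on: a WelfordOp's outputs correspond to out_avg = sum(x) / N, out_var = sum((x - avg)^2), and out_N = N, which is what the sum, divide, and squared-deviation reduction built above compute. A minimal standalone two-pass sketch of that formulation (plain C++, invented names, no nvfuser types):

#include <cstddef>
#include <iostream>
#include <vector>

// Two-pass mean / variance-sum, mirroring the decomposition emitted by
// translateSingleWelford: out_avg = sum(x) / N, out_var = sum((x - avg)^2),
// out_N = N. (WelfordOp produces the same three quantities in one pass.)
struct WelfordResult {
  double avg = 0.0;
  double var_sum = 0.0;
  std::size_t n = 0;
};

WelfordResult twoPassWelford(const std::vector<double>& x) {
  WelfordResult r;
  r.n = x.size();
  double sum = 0.0;
  for (double v : x) {
    sum += v; // x_sum = sum(in_val, red_axes)
  }
  r.avg = sum / static_cast<double>(r.n); // BinaryOpType::Div by num_features
  for (double v : x) {
    const double d = v - r.avg; // x_mean_sub
    r.var_sum += d * d;         // reduction over x_mean_sub_pow
  }
  return r;
}

int main() {
  const WelfordResult r = twoPassWelford({1.0, 2.0, 3.0, 4.0});
  std::cout << r.avg << " " << r.var_sum << " " << r.n << "\n"; // 2.5 5 4
  return 0;
}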
@@ -2687,14 +2715,20 @@ void SegmentCandidateFinder::findSegments() { } } + auto reduction_ops = + ir_utils::getReductionOps(segmented_fusion_->completeFusion()); + auto welford_ops = ir_utils::filterByType(reduction_ops); + if (options_.run_translate_welford && - segmented_fusion_->completeFusion()->hasWelford()) { + (welford_ops.begin() != welford_ops.end())) { TranslateApplicableWelford::run(segmented_fusion_.get(), runtime_inputs_); } for (auto group : groups()) { - // Set heuristics in case single reduction kernels were left out - group->setHeuristic(deriveHeuristic(group)); + if (!group->outputs().empty()) { + // Set heuristics in case single reduction kernels were left out + group->setHeuristic(deriveHeuristic(group)); + } } // Remove all scalar edges since they do not represent actual @@ -2913,7 +2947,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { group->input_vals = IterVisitor::getInputsTo(group->inputs()); // Grab all expressions needed to produce to_visit - auto input_exprs = ExprSort::getExprs(completeFusion(), to_visit); + auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit); // Insert those expressions at the beginning of the group group->exprs_.insert( @@ -3102,7 +3136,7 @@ void SegmentedFusion::annotateFP16IntermediateTensors() { } } -TORCH_CUDA_CU_API std::string toString( +std::string toString( const SegmentCandidateFinderOptions& segment_options) { std::stringstream ss; ss << "segmentation phases {\n"; diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h index 61fa966348e..63124839fc1 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h @@ -288,11 +288,11 @@ class TORCH_CUDA_CU_API SegmentedFusion { } Val* findAlias(Val* val) const { - Val* alias_val = nullptr; - if (complete_fusion_->io_alias_.count(val) != 0) { - alias_val = complete_fusion_->io_alias_[val]; + auto alias_it = complete_fusion_->ioAlias().find(val); + if (alias_it != complete_fusion_->ioAlias().end()) { + return alias_it->second; } - return alias_val; + return nullptr; } //! 
Make a clone of the group and convert to fusion diff --git a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp index 47c0316abda..b2d1f893ba6 100644 --- a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include #include @@ -19,6 +21,7 @@ #include #include +#include #include #include @@ -46,6 +49,13 @@ bool usedOnlyInDtype(Value* v) { Value* broadcastSizes(at::ArrayRef sizes) { AT_ASSERT(!sizes.empty()); Graph* graph = sizes[0]->owningGraph(); + Node* insertion_point = sizes[0]->node()->next(); + for (size_t i = 1; i < sizes.size(); i++) { + if (insertion_point->isBefore(sizes[i]->node()->next())) { + insertion_point = sizes[i]->node()->next(); + } + } + WithInsertPoint guard(insertion_point); Node* broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); broadcast_n->output()->setType(ListType::ofInts()); @@ -66,9 +76,13 @@ Value* createConditionalConstant(Node* profile_ivalue) { auto int_list = profile_ivalue->is(Symbol::attr("profiled_bool_list")); std::vector bool_list(int_list.begin(), int_list.end()); val = IValue(bool_list); - } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_size"))) { + } else if (profile_ivalue->hasAttribute( + Symbol::attr("profiled_reduction_size"))) { // int[] - val = IValue(profile_ivalue->is(Symbol::attr("profiled_size"))); + val = IValue(profile_ivalue->is(Symbol::attr("profiled_reduction_size"))); + } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_view_size"))) { + // int[] + val = IValue(profile_ivalue->is(Symbol::attr("profiled_view_size"))); } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool"))) { // bool val = IValue( @@ -101,6 +115,7 @@ struct CudaGraphFuser { std::unique_ptr aliasDb_; std::shared_ptr graph_; Symbol kind_ = prim::CudaFusionGroup; + std::unordered_map fusion_value_to_runtime_shape_; // nvrtc has a limit on the number of arguments allowed in a CUDA kernel. // The specific limit is a function of constant memory size, amount available @@ -764,9 +779,11 @@ struct CudaGraphFuser { // longer valid so we rescan the new FusionGroup for more fusions... return std::make_pair(fusion_group.value()->reverseIterator(), true); } - // horizontal fusion only applies on tensor inputs + + // horizontal fusion only applies on non-scalar tensor inputs if (getHorizontalFusion() && - producer->type()->isSubtypeOf(*TensorType::get())) { + producer->type()->isSubtypeOf(*TensorType::get()) && + !is_cpu_scalar(*producer->type()->cast())) { // fusing nodes sharing inputs, this could save memory bandwidth by // reducing number of tensor read. for (const auto& u : producer->uses()) { @@ -838,6 +855,7 @@ struct CudaGraphFuser { // Builds up expressions that compute shapes of all intermediates (and // outputs) of the fusion group, based on the sizes of inputs. You should run // DCE to remove those that you end up not using. 
+ // TODO: Add shape support for view, reshape, unsqueeze, and squeeze std::unordered_map buildShapeExpressions(Node* fusion_group) { WithInsertPoint insert_guard{fusion_group->next()}; std::unordered_map shape_of; @@ -850,7 +868,9 @@ struct CudaGraphFuser { AT_ASSERT(inputs.size() == sinputs.size()); for (const auto i : c10::irange(inputs.size())) { if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) { - shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]}); + auto sinput_value = graph->insert(aten::size, {inputs[i]}); + shape_of[sinputs[i]] = sinput_value; + sinput_value->node()->moveBefore(fusion_group); } } @@ -869,6 +889,26 @@ struct CudaGraphFuser { } } + // Place all the shape expressions for intermediates in fusion + // before the CudaFusionGroup + graph->setInsertPoint(fusion_group); + + // hmmm, do I need to setInsertPoint... + const auto map_inputs = [&](Value* v) -> Value* { + // if constant ever has an input, it has to come from + // profile_ivalue dependency + if (v->node()->kind() == prim::Param && + fusion_group->input(v->offset())->node()->kind() == + prim::profile_ivalue) { + // we need to map it along profile_ivalue dependency + return fusion_group->input(v->offset()); + } else { + throw std::runtime_error( + std::string("unexpected input from node") + + v->node()->kind().toDisplayString()); + } + }; + for (Node* n : subgraph->nodes()) { // XXX: Use of shape_of.emplace is crucial to the output shape // optimization! @@ -912,21 +952,6 @@ struct CudaGraphFuser { n->input(2)->node()->kind() == prim::Constant, "only supports reduction axes and keepdim being constant"); - // hmmm, do I need to setInsertPoint... - const auto map_inputs = [&](Value* v) -> Value* { - // if constant ever has an input, it has to come from - // profile_ivalue dependency - if (v->node()->kind() == prim::Param && - fusion_group->input(v->offset())->node()->kind() == - prim::profile_ivalue) { - // we need to map it along profile_ivalue dependency - return fusion_group->input(v->offset()); - } else { - throw std::runtime_error( - std::string("unexpected input from node") + - v->node()->kind().toDisplayString()); - } - }; Node* in1_const = graph->createClone(n->input(1)->node(), map_inputs); graph->insertNode(in1_const); Node* in2_const = graph->createClone(n->input(2)->node(), map_inputs); @@ -1000,6 +1025,57 @@ struct CudaGraphFuser { } continue; } + if (n->kind() == aten::native_dropout) { + TORCH_INTERNAL_ASSERT( + shape_of.count(n->input(0)) > 0, + "buildShapeExpressions failed at accessing input shapes"); + shape_of.emplace(n->output(0), shape_of.at(n->input(0))); + shape_of.emplace(n->output(1), shape_of.at(n->input(0))); + continue; + } + if (n->kind() == prim::unsqueeze_copy) { + TORCH_INTERNAL_ASSERT( + shape_of.count(n->input(0)) > 0, + "buildShapeExpressions failed at accessing input shapes"); + TORCH_INTERNAL_ASSERT( + n->input(1)->node()->kind() == prim::Constant, + "only supports unsqueeze axes being constant"); + Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs); + graph->insertNode(dim_const); + std::vector inputs = { + shape_of.at(n->input(0)), dim_const->output()}; + Node* size_node = graph->insertNode(graph->create( + Symbol::fromQualString("prim::infer_unsqueeze_size"), inputs, 1)); + Value* size = size_node->output(0); + size->setType(ListType::ofInts()); + shape_of.emplace(n->output(), size); + continue; + } + if (n->kind() == prim::squeeze_copy) { + TORCH_INTERNAL_ASSERT( + shape_of.count(n->input(0)) > 0, + "buildShapeExpressions failed at 
accessing input shapes"); + TORCH_INTERNAL_ASSERT( + n->inputs().size() == 2 || n->inputs().size() == 1, + "prim::squeeze_copy expects one or two inputs"); + std::vector inputs = {shape_of.at(n->input(0))}; + + if (n->inputs().size() == 2) { + TORCH_INTERNAL_ASSERT( + n->input(1)->node()->kind() == prim::Constant, + "only supports squeeze axes being constant"); + Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs); + graph->insertNode(dim_const); + inputs.push_back(dim_const->output()); + } + Node* size_node = graph->insertNode(graph->create( + Symbol::fromQualString("prim::infer_squeeze_size"), inputs, 1)); + Value* size = size_node->output(0); + size->setType(ListType::ofInts()); + shape_of.emplace(n->output(), size); + continue; + } + auto tensor_inputs = filter(n->inputs(), [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); }); @@ -1025,8 +1101,9 @@ struct CudaGraphFuser { // TODO: failure in buildShapeExpressions should not break fusion execution, // we can add a try/catch here to bailout from removeOutputsUsedOnlyInSize. GRAPH_DEBUG("before build shape expression: ", *graph_); - auto shape_of = buildShapeExpressions(fusion_group); + fusion_value_to_runtime_shape_ = buildShapeExpressions(fusion_group); GRAPH_DEBUG("after build shape expression: ", *graph_); + auto outputs = fusion_group->outputs().vec(); auto soutputs = subgraph->outputs().vec(); // XXX: Iterating in this order is not only good for performance reasons! @@ -1035,12 +1112,14 @@ struct CudaGraphFuser { for (int64_t i = static_cast(outputs.size()) - 1; i >= 0; --i) { auto output = outputs[i]; auto soutput = soutputs[i]; - if (usedOnlyInDtypeAndSize(output) && shape_of.count(soutput) > 0) { + if (usedOnlyInDtypeAndSize(output) && + fusion_value_to_runtime_shape_.count(soutput) > 0) { bool has_dtype = usedInDtype(output); auto uses = output->uses(); for (Use u : uses) { if (u.user->matches("aten::size(Tensor self) -> int[]")) { - u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); + u.user->output()->replaceAllUsesWith( + fusion_value_to_runtime_shape_.at(soutput)); u.user->destroy(); } else if (u.user->matches("prim::dtype(Tensor a) -> int")) { continue; @@ -1286,6 +1365,55 @@ void PeepholeOptimizeShapeExpressions(Block* block) { } } +// view_sizes_runtime is the profiled-ivalue argument for view-size. +// view_sizes_constant_list is the constant list recorded during profiling runs. +Value* guardView( + Node* fusion, + std::unordered_map& fusion_value_to_runtime_size, + Node* versioning_if, + Node* view, + Value* view_sizes_runtime) { + // 1. Get self tensor sizes and view_sizes + auto self_value = view->inputs().front(); + auto self_type = self_value->type()->cast(); + auto self_sizes_constant_list = getTensorSizes(self_type); + + auto view_sizes_constant_list = + constant_as>(view->inputs().back()); + TORCH_INTERNAL_ASSERT(view_sizes_constant_list.has_value()); + + // 2. Get constraints for self tensor and view_sizes + auto constraints = analyzeViewConstraint( + self_sizes_constant_list, view_sizes_constant_list->vec()); + + // 3. Add constraints as constant to graph + auto self_tensor_constraint = fusion->owningGraph()->insertConstant( + IValue(constraints.original_constraint)); + self_tensor_constraint->node()->moveBefore(versioning_if); + auto view_sizes_constraint = + fusion->owningGraph()->insertConstant(IValue(constraints.new_constraint)); + view_sizes_constraint->node()->moveBefore(versioning_if); + + // 4. 
Create CudaFusionViewGuard using input tensor, profile_ivalue + // for view_sizes list, and constraints + TORCH_INTERNAL_ASSERT( + fusion_value_to_runtime_size.find(self_value) != + fusion_value_to_runtime_size.end(), + "Failed to find runtime size for fusion value:\t", + self_value->node()->kind().toDisplayString()); + Node* viewcheck_node = + fusion->owningGraph() + ->create( + c10::Symbol::fromQualString("prim::CudaFusionViewGuard"), + {fusion_value_to_runtime_size.at(self_value), + view_sizes_runtime, + self_tensor_constraint, + view_sizes_constraint}, + 1) + ->insertBefore(versioning_if); + return viewcheck_node->output(); +} + //! [ Note -- CudaFusionGuard implementation ] //! //! shamelessly copying code from NNC (tensorexpr_fuser) with very little @@ -1324,7 +1452,9 @@ void PeepholeOptimizeShapeExpressions(Block* block) { //! //! TODO: we also need to assert/check reduction axes and replace it with //! constants in `CudaFusionGroup` -void guardFusionGroup(Node* fusion) { +void guardFusionGroup( + Node* fusion, + std::unordered_map& fusion_value_to_runtime_size) { // Fixup types of the subgraph inputs std::vector guard_types; std::vector tensor_inputs_to_check; @@ -1375,10 +1505,12 @@ void guardFusionGroup(Node* fusion) { versioning_if->insertAfter(typecheck_node); + auto fusion_graph = fusion->g(attr::Subgraph); + std::vector check_flags = {}; + // Fill in the false block. It should contain the unoptimized // copy of the fused subgraph, unless we have conditional constants from // profiled_ivalue; - auto fusion_graph = fusion->g(attr::Subgraph); std::shared_ptr fb_graph; // resource holder; // Restore the dependency for constant introduced by profiled_ivalue within // the graph. @@ -1425,11 +1557,10 @@ void guardFusionGroup(Node* fusion) { // 2. 
REMOVE conditional constant dependency in fusion group size_t compensation = 0; - // get a constant false, which is used by `and` pattern later + // get a constant true, which is used by `and` pattern later auto const_true = fusion->owningGraph()->insertConstant(IValue(true)); const_true->node()->moveBefore(versioning_if); - std::vector check_flags = {}; for (const auto& original_offset : profiled_ivalue_indices) { size_t offset = original_offset - compensation; @@ -1457,7 +1588,7 @@ void guardFusionGroup(Node* fusion) { ->insertBefore(versioning_if) ->output(); } else if (fusion->input(offset)->node()->hasAttribute( - Symbol::attr("profiled_size"))) { + Symbol::attr("profiled_reduction_size"))) { // TODO(profile_size): check sizes here with special size comparison op // TORCH_INTERNAL_ASSERT(false, "not implemented yet"); ivalue_check = @@ -1468,6 +1599,28 @@ void guardFusionGroup(Node* fusion) { 1) ->insertBefore(versioning_if) ->output(); + } else if (fusion->input(offset)->node()->hasAttribute( + Symbol::attr("profiled_view_size"))) { + // TODO: Add support for dynamic split to view guard + + // Path from profile-ivalue to prim::view_copy operation + // profile-ivalue -> Uses: [Constant, CudaFusionGroup] + // Get argument position in CudaFusionGroup + // Get argument in subgraph for CudaFusionGroup + // CudaFusionGroup argument -> Constant List -> prim::view_copy + auto cuda_fusion_group_arg = profiled_ival->uses().back().offset; + auto subgraph_arg = fusion_graph->inputs()[cuda_fusion_group_arg]; + auto constant = subgraph_arg->uses().front().user->output(); + auto view = constant->uses().front().user; + TORCH_INTERNAL_ASSERT( + view->kind() == prim::view_copy || + view->kind() == prim::reshape_copy); + ivalue_check = guardView( + fusion, + fusion_value_to_runtime_size, + versioning_if, + view, + profiled_ival); } else { ivalue_check = fusion->owningGraph() ->create(aten::eq, {profiled_ival, const_o}, 1) @@ -1495,22 +1648,24 @@ void guardFusionGroup(Node* fusion) { fusion_graph->eraseInput(offset); compensation++; } - - if (!check_flags.empty()) { - // attaching output from CudaFusionGuard to profile ivalue checks - check_flags.emplace_back(typecheck_result); - auto graph = fusion->owningGraph(); - auto bool_list_node = - graph->insertNode(graph->createList(BoolType::get(), check_flags)); - bool_list_node->moveBefore(versioning_if); - Value* bool_list = bool_list_node->output(); - // new typecheck_result - typecheck_result = graph->insert(aten::all, {bool_list}); - typecheck_result->node()->moveBefore(versioning_if); - } // update graph in fusion node fusion->g_(attr::Subgraph, fusion_graph); - } else { + } + + if (!check_flags.empty()) { + // attaching output from CudaFusionGuard to profile ivalue checks + check_flags.emplace_back(typecheck_result); + auto graph = fusion->owningGraph(); + auto bool_list_node = + graph->insertNode(graph->createList(BoolType::get(), check_flags)); + bool_list_node->moveBefore(versioning_if); + Value* bool_list = bool_list_node->output(); + // new typecheck_result + typecheck_result = graph->insert(aten::all, {bool_list}); + typecheck_result->node()->moveBefore(versioning_if); + } + + if (profiled_ivalue_indices.empty()) { WithInsertPoint guard(false_block->return_node()); const auto subgraph_outputs = insertGraph(*fusion->owningGraph(), *fusion_graph, fusion->inputs()); @@ -1536,11 +1691,13 @@ void guardFusionGroup(Node* fusion) { } } -void guardFusionGroups(Block* block) { +void guardFusionGroups( + Block* block, + std::unordered_map& 
fusion_value_to_runtime_size) { std::vector fusions; for (Node* n : block->nodes()) { for (Block* b : n->blocks()) { - guardFusionGroups(b); + guardFusionGroups(b, fusion_value_to_runtime_size); } if (n->kind() == prim::CudaFusionGroup) { fusions.push_back(n); @@ -1550,7 +1707,7 @@ void guardFusionGroups(Block* block) { // step 1: a. add prim::CudaFusionGuard and fallback logic // b. insert guard logic of profile_ivalue with if block // c. restore conditional constant to non-constant for fallback - guardFusionGroup(fusion); + guardFusionGroup(fusion, fusion_value_to_runtime_size); } } @@ -1918,6 +2075,85 @@ void decomposeLinearOps(Block* block) { } } +// Replace 'operation' with 'operation_copy' to guard alias operations. +// Supports View, Reshape, Squeeze, and Unsqueeze +void replaceAliasOpsWithCopy(std::shared_ptr& graph, Block* block) { + static std::unordered_map op_mapping( + {{aten::view, prim::view_copy}, + {aten::reshape, prim::reshape_copy}, + {aten::squeeze, prim::squeeze_copy}, + {aten::unsqueeze, prim::unsqueeze_copy}}); + + std::vector maybe_alias_nodes; + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + replaceAliasOpsWithCopy(graph, b); + } + if (op_mapping.find(n->kind()) != op_mapping.end()) { + maybe_alias_nodes.push_back(n); + } + } + + auto alias_db = std::make_unique(graph); + for (Node* n : maybe_alias_nodes) { + if (!alias_db->safeToChangeAliasingRelationship( + n->input(0), n->output(0))) { + continue; + } + + WithInsertPoint guard(n); + auto op_copy = + graph->insertNode(graph->create(op_mapping[n->kind()], n->inputs(), 1)); + op_copy->output()->setType(n->output(0)->type()); + + // adding newly created value into alias_db; + alias_db->createValue(op_copy->output()); + + n->output()->replaceAllUsesWith(op_copy->output()); + n->destroy(); + } +} + +// Revert all 'op_copy' with 'op' except in CudaFusionGroup +// e.g., Any non-fused alias operation including within the prim::FallbackGraph +// Supports View, Reshape, Squeeze, and Unsqueeze +void revertAliasCopyOps(std::shared_ptr& graph, Block* block) { + static std::unordered_map op_mapping( + {{prim::view_copy, aten::view}, + {prim::reshape_copy, aten::reshape}, + {prim::squeeze_copy, aten::squeeze}, + {prim::unsqueeze_copy, aten::unsqueeze}}); + + std::vector alias_copy_ops; + for (Node* n : block->nodes()) { + // Allow alias copy ops in CudaFusionGroup + if (n->kind() == prim::CudaFusionGroup) { + continue; + } + // Revert alias copy ops within FallbackGraph + if (n->kind() == prim::FallbackGraph) { + auto subgraph = n->g(attr::Subgraph); + revertAliasCopyOps(subgraph, subgraph->block()); + } + for (Block* b : n->blocks()) { + revertAliasCopyOps(graph, b); + } + // Revert any non-fused alias copy ops + if (op_mapping.find(n->kind()) != op_mapping.end()) { + alias_copy_ops.push_back(n); + } + } + + for (Node* n : alias_copy_ops) { + WithInsertPoint guard(n); + auto reverted_op = + graph->insertNode(graph->create(op_mapping[n->kind()], n->inputs(), 1)); + reverted_op->output()->setType(n->output(0)->type()); + n->output()->replaceAllUsesWith(reverted_op->output()); + n->destroy(); + } +} + // break `conv2d` layer into `conv2d` and `add_optional`. This allows us to fuse // the binary operation without supporting gemm. // Note that we are not breaking `conv2d` layer without bias. 
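The *_copy alias ops introduced above keep the shape semantics of their aten counterparts, and buildShapeExpressions recovers their output sizes at runtime through the prim::infer_unsqueeze_size / prim::infer_squeeze_size nodes added earlier. A standalone sketch of that size arithmetic, assuming standard aten::unsqueeze and aten::squeeze(dim) behavior; the helper names below are invented for the example and are not nvfuser code:

#include <cassert>
#include <cstdint>
#include <vector>

// Output-size helpers matching standard aten::unsqueeze / aten::squeeze(dim)
// semantics (the real prim::infer_*_size nodes are evaluated at runtime and
// are not shown in this patch).
std::vector<int64_t> inferUnsqueezeSize(std::vector<int64_t> sizes, int64_t dim) {
  const int64_t rank = static_cast<int64_t>(sizes.size());
  if (dim < 0) {
    dim += rank + 1; // unsqueeze accepts dims in [-(rank + 1), rank]
  }
  sizes.insert(sizes.begin() + dim, 1);
  return sizes;
}

std::vector<int64_t> inferSqueezeSize(
    const std::vector<int64_t>& sizes,
    int64_t dim) {
  const int64_t rank = static_cast<int64_t>(sizes.size());
  if (dim < 0) {
    dim += rank;
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    // Only the requested dimension is dropped, and only if its extent is 1.
    if (i == dim && sizes[i] == 1) {
      continue;
    }
    out.push_back(sizes[i]);
  }
  return out;
}

int main() {
  assert((inferUnsqueezeSize({4, 8}, 0) == std::vector<int64_t>{1, 4, 8}));
  assert((inferUnsqueezeSize({4, 8}, -1) == std::vector<int64_t>{4, 8, 1}));
  assert((inferSqueezeSize({1, 4, 1}, 0) == std::vector<int64_t>{4, 1}));
  assert((inferSqueezeSize({1, 4, 1}, 1) == std::vector<int64_t>{1, 4, 1}));
  return 0;
}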
@@ -2030,12 +2266,16 @@ void CudaFuseGraph(std::shared_ptr& graph) { decomposeConvOps(graph->block()); GRAPH_DEBUG("After decompose decompose Conv Ops by nvfuser: ", *graph); - CudaGraphFuser(graph->block(), graph).run(); + replaceAliasOpsWithCopy(graph, graph->block()); + GRAPH_DEBUG("replace alias_op with alias_copy by nvfuser: ", *graph); + + CudaGraphFuser cgf(graph->block(), graph); + cgf.run(); GRAPH_DEBUG("After Fusion: ", *graph); // guard input types as well as conditional constants from // aten::profile_ivalue - guardFusionGroups(graph->block()); + guardFusionGroups(graph->block(), cgf.fusion_value_to_runtime_shape_); GRAPH_DEBUG("After Guard Fusion: ", *graph); // mutate `aten::_batch_norm_impl_index` and @@ -2053,6 +2293,10 @@ void CudaFuseGraph(std::shared_ptr& graph) { // optimization targeting AMP removeOutputUsedOnlyInDtype(graph->block()); GRAPH_DEBUG("After removeOutputUsedOnlyInDtype: ", *graph); + + revertAliasCopyOps(graph, graph->block()); + GRAPH_DEBUG("revert alias_copy ops by nvfuser: ", *graph); + // After FuseGraph some common subexpressions may come back EliminateCommonSubexpression(graph); // We might have emitted a fair amount of useless shape propagating code, so diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 39176a60c53..8e151372b75 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -10,9 +10,8 @@ #include #include #include -#include -#include #include +#include #include #include #include @@ -44,9 +43,9 @@ class ContigIDs : public OptInDispatch { using OptInDispatch::handle; // Mark if ids are result of contigous merges - std::unordered_set contig_ids; + std::unordered_set contig_ids; // Given contiguous domain, return all iter domains within its history. - std::unordered_map> + std::unordered_map> within_contig_ids; const std::vector& root_domain_; const std::vector& root_contiguity_; @@ -58,7 +57,7 @@ class ContigIDs : public OptInDispatch { }); } - bool isContig(kir::IterDomain* id) { + bool isContig(IterDomain* id) { return contig_ids.find(id) != contig_ids.end(); } @@ -66,14 +65,11 @@ class ContigIDs : public OptInDispatch { void handle(Split*) override {} void handle(Merge* merge) override { - const auto gpu_lower = GpuLower::current(); - // If either input is non-contiguous so is output. const auto inner = merge->inner(); const auto outer = merge->outer(); - if ((!isContig(gpu_lower->lowerValue(inner)->as()) || - !isContig(gpu_lower->lowerValue(outer)->as()))) { + if (!isContig(inner) || !isContig(outer)) { return; } @@ -136,38 +132,34 @@ class ContigIDs : public OptInDispatch { // If we matched all inputs, the output is contiguous. Only want to keep the // top contig ID, lower ids should be placed in the "within_contig_ids" map // of top id. 
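For context on what ContigIDs buys: when the merged root domains are contiguous, the merged domain can be indexed with a single flat index instead of un-merging with div/mod per axis and recombining through strides. A standalone sketch of why the two agree for a contiguous layout (invented extents, not nvfuser code):

#include <cassert>
#include <cstdint>

int main() {
  // Root domains with extents e0 x e1, iterated through one merged domain.
  const int64_t e0 = 4, e1 = 8;
  const int64_t stride0 = e1, stride1 = 1; // fully contiguous layout

  for (int64_t idx = 0; idx < e0 * e1; ++idx) {
    // General path: un-merge the flat index with div/mod, then apply strides.
    const int64_t i = idx / e1;
    const int64_t j = idx % e1;
    const int64_t general = i * stride0 + j * stride1;

    // Contiguous path: the merged domain can be indexed directly, which is
    // what recording the merge output in contig_ids enables.
    const int64_t contiguous = idx;

    assert(general == contiguous);
  }
  return 0;
}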
- auto kir_inner = - gpu_lower->lowerValue(merge->inner())->as(); - auto kir_outer = - gpu_lower->lowerValue(merge->outer())->as(); - auto kir_out = gpu_lower->lowerValue(merge->out())->as(); + auto out = merge->out()->as(); if (ordered_inputs.empty()) { - if (contig_ids.find(kir_inner) != contig_ids.end()) { - contig_ids.erase(kir_inner); + if (contig_ids.find(inner) != contig_ids.end()) { + contig_ids.erase(inner); } - if (contig_ids.find(kir_outer) != contig_ids.end()) { - contig_ids.erase(kir_outer); + if (contig_ids.find(outer) != contig_ids.end()) { + contig_ids.erase(outer); } - contig_ids.emplace(kir_out); + contig_ids.emplace(out); - std::unordered_set within_out; - within_out.emplace(kir_inner); - if (within_contig_ids.find(kir_inner) != within_contig_ids.end()) { - auto in_inner = within_contig_ids.at(kir_inner); + std::unordered_set within_out; + within_out.emplace(inner); + if (within_contig_ids.find(inner) != within_contig_ids.end()) { + auto in_inner = within_contig_ids.at(inner); within_out.insert(in_inner.begin(), in_inner.end()); - within_contig_ids.erase(kir_inner); + within_contig_ids.erase(inner); } - within_out.emplace(kir_outer); - if (within_contig_ids.find(kir_outer) != within_contig_ids.end()) { - auto in_outer = within_contig_ids.at(kir_outer); + within_out.emplace(outer); + if (within_contig_ids.find(outer) != within_contig_ids.end()) { + auto in_outer = within_contig_ids.at(outer); within_out.insert(in_outer.begin(), in_outer.end()); - within_contig_ids.erase(kir_outer); + within_contig_ids.erase(outer); } - within_contig_ids[kir_out] = within_out; + within_contig_ids[out] = within_out; } } @@ -195,8 +187,6 @@ class ContigIDs : public OptInDispatch { " != ", root_contiguity_.size()); - const auto gpu_lower = GpuLower::current(); - for (const auto i : c10::irange(root_domain_.size())) { // If a root domain has halo, can't use merged domain even if // both inputs are contiguous. HaloInfo is also initialized for @@ -204,32 +194,32 @@ class ContigIDs : public OptInDispatch { // RootAxisInfo. This should be safe as no rfactor tensor should // need halo. 
if (root_contiguity_[i] && - !gpu_lower->haloInfo().getRootAxisInfo(root_domain_[i]).hasHalo()) { - auto kir_root_domain_i = - gpu_lower->lowerValue(root_domain_[i])->as(); - contig_ids.emplace(kir_root_domain_i); - within_contig_ids[kir_root_domain_i] = - std::unordered_set(); + !GpuLower::current() + ->haloInfo() + .getRootAxisInfo(root_domain_[i]) + .hasHalo()) { + auto root_domain_i = root_domain_[i]->as(); + contig_ids.emplace(root_domain_i); + within_contig_ids[root_domain_i] = std::unordered_set(); is_contig_root[root_domain_[i]] = true; } else { is_contig_root[root_domain_[i]] = false; } } - auto exprs = ExprSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()}); + auto exprs = StmtSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()}); for (auto expr : exprs) { handle(expr); } } - const std::unordered_set contigIDs() const { + const std::unordered_set contigIDs() const { return contig_ids; } - const std:: - unordered_map> - withinContigIDs() const { + const std::unordered_map> + withinContigIDs() const { return within_contig_ids; } }; @@ -276,21 +266,18 @@ void updateHaloInfoForReference( // // ref_map: ref-to-consumer in consumer indexing; ref-to-producer in // producer indexing -std::unordered_map getReferenceHaloExtentMap( +std::unordered_map getReferenceHaloExtentMap( const ReferenceTensor& reference, const std::unordered_map& index_map_from_ref) { - const auto gpu_lower = GpuLower::current(); - - const auto& halo_info = gpu_lower->haloInfo(); + const auto& halo_info = GpuLower::current()->haloInfo(); - std::unordered_map reference_halo_extent_map; + std::unordered_map reference_halo_extent_map; // Propagate halo extents of the reference to the consumer or // producer tensor for (auto kv : index_map_from_ref) { - auto ref_id = gpu_lower->lowerValue(kv.first)->as(); - auto producer_or_consumer_id = - gpu_lower->lowerValue(kv.second)->as(); + auto ref_id = kv.first; + auto producer_or_consumer_id = kv.second; auto extent = halo_info.getExtent(ref_id); if (extent != nullptr) { reference_halo_extent_map[producer_or_consumer_id] = extent; @@ -302,7 +289,7 @@ std::unordered_map getReferenceHaloExtentMap( //! Offset of an index of a producer axis with respect to its //! corresponding consumer index -kir::Val* getProducerHaloOffset( +int getProducerHaloOffset( const TensorView* producer_tv, size_t producer_axis, const TensorView* consumer_tv) { @@ -325,41 +312,31 @@ kir::Val* getProducerHaloOffset( const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0); const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - kir::Val* offset = (p_pad->isConst() && c_pad->isConst()) - ? ir_builder.create( - p_pad->value().value() - c_pad->value().value()) - : ir_builder.subExpr(p_pad, c_pad); + auto offset = p_pad - c_pad; // If the consumer is a result of shifting the producer, adjust the // producer index per the offsets argument of the shift op. if (auto shift_op = dynamic_cast(consumer_tv->definition())) { - offset = ir_builder.subExpr( - offset, ir_builder.create(shift_op->offset(producer_axis))); + offset -= shift_op->offset(producer_axis); } return offset; } //! 
Offset producer index when necessary -kir::Val* getProducerIndexWithHalo( +Val* getProducerIndexWithHalo( const TensorView* producer_tv, size_t producer_axis, - kir::Val* producer_index, + Val* producer_index, const TensorView* consumer_tv) { const auto offset = getProducerHaloOffset(producer_tv, producer_axis, consumer_tv); - if (offset->isZeroInt()) { + if (offset == 0) { return producer_index; } - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - producer_index = ir_builder.addExpr(producer_index, offset); + producer_index = SimplifyingIrBuilder::addExpr(producer_index, offset); return producer_index; } @@ -368,58 +345,58 @@ kir::Val* getProducerIndexWithHalo( //! //! \param consumer_root_axis Position of corresponding consumer axis //! \param consumer_tv Consumer TensorView +//! \param index_map Mappings from consumer or reference to indices +//! \param use_reference_map True when index_map maps reference domains //! \param concrete_to_ref_map Mappings from concrete to reference domains -//! \param ref_index_map Mappings from reference domains to indices -kir::Val* getProducerOffsetWithGather( +Val* getProducerOffsetWithGather( size_t consumer_root_axis, const TensorView* consumer_tv, - const std::unordered_map& concrete_to_ref_map, - const std::unordered_map& ref_index_map) { + const std::unordered_map& index_map, + bool use_reference_map = false, + const std::unordered_map& concrete_to_ref_map = + {}) { const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); const auto gather_expr = dynamic_cast(consumer_tv->definition()); if (gather_expr == nullptr) { - return ir_builder.zeroVal(); + return gpu_lower->kernel()->zeroVal(); } // If the window extent is one, no specific offsetting // is necessary if (consumer_root_axis >= gather_expr->windowShape().size() || - gather_expr->windowShape()[consumer_root_axis]->isOneInt()) { - return ir_builder.zeroVal(); + gather_expr->windowShape()[consumer_root_axis] == 1) { + return gpu_lower->kernel()->zeroVal(); } // Basically, the goal is to build an expression of producer_index + // window_index, so we first need to locate the index expression // that corresponds to the window axis of this producer axis. - // Locate the root IterDomain of the reference that corresponds to the gather - // axis const auto window_axis = gather_expr->gatherAxis(consumer_root_axis); auto window_id = consumer_tv->getRootDomain().at(window_axis); - auto concrete_window_id = - gpu_lower->caIndexMap().getConcreteMappedID(window_id); - auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id); - TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end()); - IterDomain* reference_root_of_gather_axis = concrete_2_ref_it->second; - - // Now that reference_root_of_gather_axis is the IterDomain for the - // window axis, take its corresponding index from the index map - auto window_idx = - ref_index_map.at(gpu_lower->lowerValue(reference_root_of_gather_axis) - ->as()); - - // Positive (or negative) padding at offset zero means the indexing - // shifted to the negative (or positive) direction. + + // When index_map maps a reference tensor, find the corresponding + // reference ID of window_id. 
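The gather offsetting above reduces to simple integer arithmetic: the producer element read at consumer position c and window position w is c + (w - pad), where pad is the leading (offset-zero) padding, matching producer_offset = window_idx - pad_width followed by the add in getProducerIndexWithGather. A standalone sketch with invented values:

#include <cassert>
#include <cstdint>

// producer_offset = window_idx - pad; producer_index = consumer_index + offset
int64_t producerIndex(int64_t consumer_index, int64_t window_idx, int64_t pad) {
  const int64_t producer_offset = window_idx - pad;
  return consumer_index + producer_offset;
}

int main() {
  // 3-wide window, pad 1 on the left: window position 1 reads the element at
  // the consumer position itself.
  assert(producerIndex(/*consumer_index=*/5, /*window_idx=*/1, /*pad=*/1) == 5);
  // Window position 0 reads one element to the left.
  assert(producerIndex(5, 0, 1) == 4);
  return 0;
}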
+ if (use_reference_map) { + auto concrete_window_id = + gpu_lower->caIndexMap().getConcreteMappedID(window_id); + auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id); + TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end()); + window_id = concrete_2_ref_it->second; + } + + auto window_idx = index_map.at(window_id); + + // Positive padding at offset zero means the indexing shifted to the + // negative direction. auto pad_width = gather_expr->padWidth()[consumer_root_axis][0]; // producer offset: window_index - padding - auto producer_offset = - ir_builder.subExpr(window_idx, ir_builder.create(pad_width)); + auto producer_offset = SimplifyingIrBuilder::subExpr( + window_idx, IrBuilder::create(pad_width)); return producer_offset; - ; } //! Offset a producer index of a gather expression @@ -428,13 +405,13 @@ kir::Val* getProducerOffsetWithGather( //! expression that accesses a window position that the current loop //! structure refers to. Use getGatherProducerOffset to create an //! offset Val. -kir::Val* getProducerIndexWithGather( - kir::Val* producer_index, +Val* getProducerIndexWithGather( + Val* producer_index, size_t producer_root_axis, const TensorView* producer_tv, const TensorView* consumer_tv, const std::unordered_map& concrete_to_ref_map, - const std::unordered_map& ref_index_map) { + const std::unordered_map& ref_index_map) { auto gather_op = dynamic_cast(consumer_tv->definition()); // Just return the producer index as is if this is not a gather @@ -460,22 +437,18 @@ kir::Val* getProducerIndexWithGather( ", producer_axis: ", producer_root_axis); - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); auto offset = getProducerOffsetWithGather( - consumer_axis, consumer_tv, concrete_to_ref_map, ref_index_map); - return ir_builder.addExpr(producer_index, offset); + consumer_axis, consumer_tv, ref_index_map, true, concrete_to_ref_map); + return SimplifyingIrBuilder::addExpr(producer_index, offset); } // Adjusts a global consumer index when its root domain is partially // split. Note that non-global consumer indices don't need any // adjustment. -kir::Val* getGlobalConsumerOffsetWithPartialSplit(kir::IterDomain* root_id) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto offset = gpu_lower->partialSplitMap().getStartOffset(root_id); +Val* getGlobalConsumerOffsetWithPartialSplit(IterDomain* root_id) { + auto offset = GpuLower::current()->partialSplitMap().getStartOffset(root_id); if (offset == nullptr) { - return ir_builder.zeroVal(); + return GpuLower::current()->kernel()->zeroVal(); } else { return offset; } @@ -488,13 +461,12 @@ kir::Val* getGlobalConsumerOffsetWithPartialSplit(kir::IterDomain* root_id) { // it needs to be added to the index. Also, when the producer itself // also has a non-zero split offset, that needs to be subtracted from // the index. -kir::Val* getProducerIndexWithPartialSplit( - kir::Val* producer_index, +Val* getProducerIndexWithPartialSplit( + Val* producer_index, IterDomain* producer_root_id, const TensorView* producer_tv, const TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); auto p2c = PairwiseRootDomainMap(producer_tv, consumer_tv) @@ -509,31 +481,29 @@ kir::Val* getProducerIndexWithPartialSplit( auto consumer_offset = gpu_lower->partialSplitMap().getStartOffset(consumer_root_id); - auto consumer_offset_kir = consumer_offset == nullptr - ? 
ir_builder.zeroVal() - : gpu_lower->lowerValue(consumer_offset); + consumer_offset = consumer_offset == nullptr ? gpu_lower->kernel()->zeroVal() + : consumer_offset; auto producer_offset = gpu_lower->partialSplitMap().getStartOffset(producer_root_id); - auto producer_offset_kir = producer_offset == nullptr - ? ir_builder.zeroVal() - : gpu_lower->lowerValue(producer_offset); + producer_offset = producer_offset == nullptr ? gpu_lower->kernel()->zeroVal() + : producer_offset; // If the producer is on global memory, it's always allocated // without trimming the out-of-bounds region, so the consumer offset // should be added to the index. if (producer_tv->getMemoryType() == MemoryType::Global) { - if (consumer_offset_kir->isZeroInt()) { + if (consumer_offset->isZeroInt()) { return producer_index; } else { - return ir_builder.addExpr(producer_index, consumer_offset_kir); + return IrBuilder::addExpr(producer_index, consumer_offset); } } // Non-global case. Difference of the split offsets must be // accounted. - auto diff = ir_builder.subExpr(consumer_offset_kir, producer_offset_kir); + auto diff = IrBuilder::subExpr(consumer_offset, producer_offset); kir::ExpressionEvaluator ee; auto diff_eval = ee.evaluate(diff); // We currently only allow constant offsetting @@ -543,19 +513,16 @@ kir::Val* getProducerIndexWithPartialSplit( return producer_index; } - return ir_builder.addExpr( - producer_index, ir_builder.create(diff_eval.value())); + return IrBuilder::addExpr( + producer_index, IrBuilder::create(diff_eval.value())); } } // namespace void IndexCompute::handle(Split* split) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto in_id = gpu_lower->lowerValue(split->in())->as(); - auto outer_id = gpu_lower->lowerValue(split->outer())->as(); - auto inner_id = gpu_lower->lowerValue(split->inner())->as(); + auto in_id = split->in()->as(); + auto outer_id = split->outer()->as(); + auto inner_id = split->inner()->as(); auto outer_it = index_map_.find(outer_id); auto inner_it = index_map_.find(inner_id); @@ -588,8 +555,8 @@ void IndexCompute::handle(Split* split) { } if (isZero(in_id)) { - index_map_[in_id] = ir_builder.create(0); - extent_map_[in_id] = ir_builder.create(0); + index_map_[in_id] = GpuLower::current()->kernel()->zeroVal(); + extent_map_[in_id] = GpuLower::current()->kernel()->zeroVal(); } else if (zero_merged_in && outer_zero) { index_map_[in_id] = inner_ind; extent_map_[in_id] = getExtent(inner_id); @@ -597,24 +564,21 @@ void IndexCompute::handle(Split* split) { index_map_[in_id] = outer_ind; extent_map_[in_id] = getExtent(outer_id); } else { - index_map_[in_id] = ir_builder.addExpr( - ir_builder.mulExpr(outer_ind, getExtent(inner_id)), inner_ind); + index_map_[in_id] = IrBuilder::addExpr( + IrBuilder::mulExpr(outer_ind, getExtent(inner_id)), inner_ind); // The extent should be updated only when its allocation is // partial, i.e., zero_merged_in is true. See PR #1270. 
if (zero_merged_in) { extent_map_[in_id] = - ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id)); + IrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id)); } } } void IndexCompute::handle(Merge* merge) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto out_id = gpu_lower->lowerValue(merge->out())->as(); - auto outer_id = gpu_lower->lowerValue(merge->outer())->as(); - auto inner_id = gpu_lower->lowerValue(merge->inner())->as(); + auto out_id = merge->out(); + auto outer_id = merge->outer(); + auto inner_id = merge->inner(); auto out_it = index_map_.find(out_id); if (out_it == index_map_.end()) { @@ -622,7 +586,7 @@ void IndexCompute::handle(Merge* merge) { } auto out_ind = out_it->second; - auto zero = ir_builder.zeroVal(); + auto zero = GpuLower::current()->kernel()->zeroVal(); if (isZero(out_id)) { index_map_[outer_id] = zero; @@ -643,17 +607,14 @@ void IndexCompute::handle(Merge* merge) { TORCH_INTERNAL_ASSERT(!input_ids.empty()); for (auto root_id : input_ids) { - index_map_[gpu_lower->lowerValue(root_id)->as()] = zero; + index_map_[root_id] = zero; } - index_map_[gpu_lower - ->lowerValue(*(input_ids.end() - 1)) - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ->as()] = out_ind; + index_map_[*(input_ids.end() - 1)] = out_ind; return; } - kir::Val* inner_extent = getExtent(inner_id); + Val* inner_extent = getExtent(inner_id); // When the reference has halo extent for inner_id, that extent needs to // be used to un-merge @@ -718,8 +679,8 @@ void IndexCompute::handle(Merge* merge) { zero_merged_in_.emplace(inner_id); zero_merged_in_.emplace(outer_id); } else { - index_map_[outer_id] = ir_builder.divExpr(out_ind, inner_extent); - index_map_[inner_id] = ir_builder.modExpr(out_ind, inner_extent); + index_map_[outer_id] = IrBuilder::divExpr(out_ind, inner_extent); + index_map_[inner_id] = IrBuilder::modExpr(out_ind, inner_extent); } } @@ -739,13 +700,13 @@ void IndexCompute::handle(Expr* e) { // using TransformIter::runBackward; IndexCompute::IndexCompute( const TensorDomain* _td, - std::unordered_map initial_index_map, - std::unordered_map extent_map, - std::unordered_set zero_domains, - std::unordered_set zero_merged_in, + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in, const std::vector& root_contiguity, - std::unordered_set preferred_paths, - std::unordered_map reference_halo_extent_map) + std::unordered_set preferred_paths, + std::unordered_map reference_halo_extent_map) : td_(_td), index_map_(std::move(initial_index_map)), extent_map_(std::move(extent_map)), @@ -783,7 +744,7 @@ void IndexCompute::run() { traverseFrom(td_->fusion(), domain_vals, false); } -kir::Val* IndexCompute::getExtent(kir::IterDomain* id) { +Val* IndexCompute::getExtent(IterDomain* id) { // Pick from extent_map_ if available. 
Previously parallel // dimensions were ued (e.g., blockDim.x), however, it would result // in out-of-bounds errors when the extent of IterDomain is smaller @@ -795,11 +756,11 @@ kir::Val* IndexCompute::getExtent(kir::IterDomain* id) { } } -bool IndexCompute::hasZeroMerged(kir::IterDomain* id) const { +bool IndexCompute::hasZeroMerged(IterDomain* id) const { return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id); } -bool IndexCompute::isZero(kir::IterDomain* id) const { +bool IndexCompute::isZero(IterDomain* id) const { return zero_domains_.find(id) != zero_domains_.end(); } @@ -807,22 +768,17 @@ IndexCompute IndexCompute::updateIndexCompute( const TensorDomain* new_td, const std::unordered_map& id_map, const std::vector& root_contiguity, - const std::unordered_map& - reference_halo_extent_map) { + const std::unordered_map& reference_halo_extent_map) { FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute"); - const auto gpu_lower = GpuLower::current(); - - std::unordered_map updated_index_map; - std::unordered_map updated_extent_map; - std::unordered_set updated_zero_domains; - std::unordered_set updated_zero_merged_in; + std::unordered_map updated_index_map; + std::unordered_map updated_extent_map; + std::unordered_set updated_zero_domains; + std::unordered_set updated_zero_merged_in; for (auto id_entry : id_map) { - kir::IterDomain* prev_id = - gpu_lower->lowerValue(id_entry.first)->as(); - kir::IterDomain* new_id = - gpu_lower->lowerValue(id_entry.second)->as(); + IterDomain* prev_id = id_entry.first; + IterDomain* new_id = id_entry.second; if (index_map_.find(prev_id) != index_map_.end()) { updated_index_map[new_id] = index_map_.at(prev_id); @@ -859,8 +815,8 @@ class UpdateLeafIndices : public IterVisitor { public: UpdateLeafIndices( const TensorDomain* td, - std::unordered_map initial_index_map, - std::unordered_map extent_map) + std::unordered_map initial_index_map, + std::unordered_map extent_map) : td_(td), index_map_(std::move(initial_index_map)), extent_map_(std::move(extent_map)) { @@ -870,11 +826,11 @@ class UpdateLeafIndices : public IterVisitor { traverseFrom(td_->fusion(), domain_vals, false); } - const std::unordered_map& indexMap() const { + const std::unordered_map& indexMap() const { return index_map_; } - const std::unordered_map& extentMap() const { + const std::unordered_map& extentMap() const { return extent_map_; } @@ -882,13 +838,9 @@ class UpdateLeafIndices : public IterVisitor { using IterVisitor::handle; void handle(Split* split) override { - const auto gpu_lower = GpuLower::current(); - - auto in_id = gpu_lower->lowerValue(split->in())->as(); - auto outer_id = - gpu_lower->lowerValue(split->outer())->as(); - auto inner_id = - gpu_lower->lowerValue(split->inner())->as(); + auto in_id = split->in(); + auto outer_id = split->outer(); + auto inner_id = split->inner(); // Nothing need to be done when mappings for the output axes // already exist. 
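UpdateLeafIndices and IndexCompute apply the same split/merge index algebra in opposite directions: a split by factor F sends an index x to (x / F, x % F), and the pre-split (or merged) index is recovered as outer * F + inner. A standalone round-trip check of that algebra (plain C++, invented extents):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t extent = 12; // extent of the original domain
  const int64_t factor = 4;  // split factor, i.e. the inner extent

  for (int64_t x = 0; x < extent; ++x) {
    // Forward split, as in UpdateLeafIndices::handle(Split):
    //   inner = in % factor, outer = in / factor.
    const int64_t outer = x / factor;
    const int64_t inner = x % factor;

    // Recombine, as in IndexCompute::handle(Split) when indexing the
    // pre-split domain: in = outer * extent(inner) + inner.
    const int64_t merged = outer * factor + inner;
    assert(merged == x);

    // Un-merge, as in IndexCompute::handle(Merge):
    //   outer = out / extent(inner), inner = out % extent(inner).
    assert(merged / factor == outer && merged % factor == inner);
  }
  return 0;
}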
@@ -899,22 +851,17 @@ class UpdateLeafIndices : public IterVisitor { return; } - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto factor = gpu_lower->lowerValue(split->factor()); - index_map_[inner_id] = ir_builder.modExpr(index_map_[in_id], factor); + auto factor = split->factor(); + index_map_[inner_id] = IrBuilder::modExpr(index_map_[in_id], factor); extent_map_[inner_id] = factor; - index_map_[outer_id] = ir_builder.divExpr(index_map_[in_id], factor); - extent_map_[outer_id] = ir_builder.ceilDivExpr(getExtent(in_id), factor); + index_map_[outer_id] = IrBuilder::divExpr(index_map_[in_id], factor); + extent_map_[outer_id] = IrBuilder::ceilDivExpr(getExtent(in_id), factor); } void handle(Merge* merge) override { - const auto gpu_lower = GpuLower::current(); - - auto out_id = gpu_lower->lowerValue(merge->out())->as(); - auto outer_id = - gpu_lower->lowerValue(merge->outer())->as(); - auto inner_id = - gpu_lower->lowerValue(merge->inner())->as(); + auto out_id = merge->out(); + auto outer_id = merge->outer(); + auto inner_id = merge->inner(); // Nothing need to be done when mappings for the output axes // already exist. @@ -927,17 +874,16 @@ class UpdateLeafIndices : public IterVisitor { TORCH_INTERNAL_ASSERT( index_map_.find(inner_id) != index_map_.end(), "Inner ID not found"); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - index_map_[out_id] = ir_builder.mulExpr( + index_map_[out_id] = IrBuilder::mulExpr( index_map_[inner_id], - ir_builder.mulExpr(index_map_[outer_id], getExtent(inner_id))); + IrBuilder::mulExpr(index_map_[outer_id], getExtent(inner_id))); extent_map_[out_id] = - ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id)); + IrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id)); } // return extent_map_[id] if exists, else return id->extent() - kir::Val* getExtent(kir::IterDomain* id) { + Val* getExtent(IterDomain* id) { if (extent_map_.find(id) != extent_map_.end()) { return extent_map_.at(id); } else { @@ -947,25 +893,21 @@ class UpdateLeafIndices : public IterVisitor { private: const TensorDomain* td_; - std::unordered_map index_map_; - std::unordered_map extent_map_; + std::unordered_map index_map_; + std::unordered_map extent_map_; }; // Returns halo-extended extent if id has halo. Otherwise, just // returns id->extent. 
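UpdateLeafIndices now builds the same split and merge bookkeeping through IrBuilder. A minimal plain-integer sketch of the Split case, assuming nothing beyond the div/mod/ceilDiv expressions visible above (ceilDiv here stands in for IrBuilder::ceilDivExpr):

#include <cassert>

// Plain-integer mirror of the Split bookkeeping above (illustrative only):
// inner = in % factor, outer = in / factor, and the outer extent is the
// ceiling division of the input extent by the factor.
long ceilDiv(long a, long b) { return (a + b - 1) / b; }

int main() {
  const long in_extent = 10, factor = 4, in_idx = 7;

  long inner_idx = in_idx % factor;               // IrBuilder::modExpr
  long outer_idx = in_idx / factor;               // IrBuilder::divExpr
  long outer_extent = ceilDiv(in_extent, factor); // IrBuilder::ceilDivExpr
  assert(inner_idx == 3 && outer_idx == 1 && outer_extent == 3);

  // A later merge multiplies the extents back together (IrBuilder::mulExpr),
  // which may over-approximate the original extent when the split was not
  // divisible: 12 slots for 10 elements here.
  long merged_extent = outer_extent * factor;
  assert(merged_extent == 12);
  return 0;
}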
-kir::Val* getHaloExtentOfRootAxis( - IterDomain* id, - kir::Val* normal_extent = nullptr) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - +Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) { if (normal_extent == nullptr) { - normal_extent = gpu_lower->lowerValue(id->extent()); + normal_extent = id->extent(); } - const auto& halo = gpu_lower->haloInfo().getRootAxisInfo(id); + const auto& halo = GpuLower::current()->haloInfo().getRootAxisInfo(id); if (halo.hasHalo()) { - auto halo_extent = ir_builder.addExpr(normal_extent, halo.width()); + auto halo_extent = + IrBuilder::addExpr(normal_extent, IrBuilder::create(halo.width())); return halo_extent; } else { return normal_extent; @@ -976,10 +918,10 @@ kir::Val* getHaloExtentOfRootAxis( IndexSwizzle::IndexSwizzle( const TensorView* tv, - std::unordered_map initial_index_map, - std::unordered_map extent_map, - std::unordered_set zero_domains, - std::unordered_set zero_merged_in) + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in) : IndexCompute( tv->domain(), std::move(initial_index_map), @@ -996,8 +938,6 @@ void IndexSwizzle::run() { swizzle_type_ == SwizzleType::NoSwizzle || swizzle_type_ == SwizzleType::Transpose, "Invalid swizzle type"); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); if (swizzle_type_ == SwizzleType::Transpose) { // Shifts the second axis by the first axis as ((idx_1 + idx_2) % // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also @@ -1013,20 +953,16 @@ void IndexSwizzle::run() { IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0); IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1); - kir::IterDomain* id_to_swizzle_i_kir = - gpu_lower->lowerValue(id_to_swizzle_i)->as(); - kir::IterDomain* id_to_swizzle_j_kir = - gpu_lower->lowerValue(id_to_swizzle_j)->as(); - - if (indexMap().find(id_to_swizzle_i_kir) != indexMap().end() && - indexMap().find(id_to_swizzle_j_kir) != indexMap().end()) { - auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i_kir); - auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j_kir); - - auto swizzled_idx = ir_builder.modExpr( - ir_builder.addExpr(idx_to_swizzle_i, idx_to_swizzle_j), - id_to_swizzle_j_kir->extent()); - index_map_[id_to_swizzle_j_kir] = swizzled_idx; + + if (indexMap().find(id_to_swizzle_i) != indexMap().end() && + indexMap().find(id_to_swizzle_j) != indexMap().end()) { + auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i); + auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j); + + auto swizzled_idx = IrBuilder::modExpr( + IrBuilder::addExpr(idx_to_swizzle_i, idx_to_swizzle_j), + id_to_swizzle_j->extent()); + index_map_[id_to_swizzle_j] = swizzled_idx; swizzled_ids_.insert(id_to_swizzle_j); IndexCompute::run(); } @@ -1055,17 +991,15 @@ namespace { // to loop indices as well as a set of loops that do not contribute to // indexing. 
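The Transpose swizzle now works directly on fusion IR IterDomains, but the formula is unchanged: the second index is shifted by the first, modulo the second extent. A small illustrative sketch of that staggering pattern (swizzleTranspose is a hypothetical name), which is a common way to spread shared-memory accesses across banks:

#include <cassert>

// The swizzle above computes swizzled_j = (idx_i + idx_j) % extent_j
// (IrBuilder::addExpr followed by IrBuilder::modExpr).
long swizzleTranspose(long idx_i, long idx_j, long extent_j) {
  return (idx_i + idx_j) % extent_j;
}

int main() {
  const long extent_j = 32;
  // Each row writes its columns at a different rotation, so reading a column
  // of the transposed tile no longer hits a single shared-memory bank.
  assert(swizzleTranspose(0, 5, extent_j) == 5);
  assert(swizzleTranspose(1, 5, extent_j) == 6);
  assert(swizzleTranspose(31, 5, extent_j) == 4); // wraps around
  return 0;
}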
std::pair< - std::unordered_map, + std::unordered_map, std::unordered_set> indexMapFromTV( const TensorView* tv, const std::vector& loops, - const std::pair& alloc_point, - bool as_consumer) { + kir::ForLoop* alloc_loop, + bool as_consumer, + kir::ForLoop* double_buffer_loop = nullptr) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto alloc_loop = alloc_point.first; bool within_alloc = false; if (alloc_loop == nullptr) { @@ -1076,7 +1010,7 @@ indexMapFromTV( const bool is_shared = tv->getMemoryType() == MemoryType::Shared; const bool is_local = tv->getMemoryType() == MemoryType::Local; - std::unordered_map loop_to_ind_map; + std::unordered_map loop_to_ind_map; // When indexed as a producer, the parallel types of the the // producer domains may not be the same as those of the loops, but @@ -1085,17 +1019,16 @@ indexMapFromTV( // with zero isn't valid. That's only valid when there's a matching // IterDomain in the producer tensor that has the same parallel // type. - auto find_matching_parallel_domain = [tv](kir::IterDomain* id) -> bool { + auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool { const auto gpu_lower = GpuLower::current(); auto it = std::find_if( tv->domain()->domain().begin(), tv->domain()->domain().end(), [&](IterDomain* tv_id) { - auto kir_tv_id = gpu_lower->lowerValue(tv_id)->as(); // Matching is done using the index and loop maps. See // validateParallelize as well. - return gpu_lower->caIndexMap().areMapped(id, kir_tv_id) || - (gpu_lower->caLoopMap().areMapped(id, kir_tv_id) && + return gpu_lower->caIndexMap().areMapped(id, tv_id) || + (gpu_lower->caLoopMap().areMapped(id, tv_id) && ir_utils::derivedFromRootCAAxes(tv, tv_id)); }); if (it == tv->domain()->domain().end()) { @@ -1103,7 +1036,7 @@ indexMapFromTV( } auto corresponding_domain = *it; - return corresponding_domain->getParallelType() == id->parallelType(); + return corresponding_domain->getParallelType() == id->getParallelType(); }; // Track domains that do not contibute to the resulting @@ -1113,7 +1046,7 @@ indexMapFromTV( std::unordered_set zero_loops; for (auto loop : loops) { - kir::Val* idx = nullptr; + Val* idx = nullptr; const auto same_parallel_type = as_consumer || find_matching_parallel_domain(loop->iter_domain()); // See also LoopNestGenerator::pushAlloc. 
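The zero_loops set collects loops whose index must not appear in the final offset because the allocation is private to that parallel dimension. A toy model of the idea, assuming a shared-memory buffer (the function and parameter names are illustrative, not from the patch):

#include <cassert>

// A shared-memory buffer is per-block, so a blockIdx-bound loop does not
// contribute to its address; thread-bound and serial loops still do.
long sharedMemOffset(long block_idx, long thread_idx, long serial_idx,
                     long serial_extent) {
  (void)block_idx; // replaced by zero and recorded in zero_loops
  return thread_idx * serial_extent + serial_idx;
}

int main() {
  // Two different blocks compute the same shared-memory offset.
  assert(sharedMemOffset(0, 3, 2, 8) == sharedMemOffset(5, 3, 2, 8));
  return 0;
}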
@@ -1123,7 +1056,7 @@ indexMapFromTV( (loop->iter_domain()->isThread() && is_global)) { idx = loop->index(); } else { - idx = ir_builder.zeroVal(); + idx = GpuLower::current()->kernel()->zeroVal(); zero_loops.insert(loop); } } else if ( @@ -1145,7 +1078,7 @@ indexMapFromTV( // parallel type (loop->iter_domain()->isThread() && is_local && same_parallel_type) || loop->vectorize()) { - idx = ir_builder.zeroVal(); + idx = GpuLower::current()->kernel()->zeroVal(); if (!loop->vectorize()) { zero_loops.insert(loop); } @@ -1153,6 +1086,10 @@ indexMapFromTV( idx = loop->index(); } + if (loop == double_buffer_loop) { + idx = IrBuilder::addExpr(idx, GpuLower::current()->kernel()->oneVal()); + } + loop_to_ind_map[loop] = idx; if (!within_alloc && loop == alloc_loop) { @@ -1184,8 +1121,6 @@ void ensureStaticIndexing( within_alloc = true; } - const auto gpu_lower = GpuLower::current(); - for (auto loop : loops) { if (!within_alloc) { if (loop == alloc_loop) { @@ -1193,7 +1128,7 @@ void ensureStaticIndexing( } continue; } - kir::IterDomain* loop_id = loop->iter_domain(); + IterDomain* loop_id = loop->iter_domain(); if (loop->vectorize() || loop_id->isThread()) { continue; } @@ -1203,7 +1138,7 @@ void ensureStaticIndexing( auto it = std::find_if( tv->domain()->domain().begin(), tv->domain()->domain().end(), - [loop_id, gpu_lower, &id_map](IterDomain* id) { + [loop_id, &id_map](IterDomain* id) { if (id->isBroadcast() || id->isReduction() || id->isStride()) { return false; } @@ -1211,8 +1146,7 @@ void ensureStaticIndexing( if (id_replacement != id_map.end()) { id = id_replacement->second; } - auto kir_id = gpu_lower->lowerValue(id)->as(); - return gpu_lower->caLoopMap().areMapped(loop_id, kir_id); + return GpuLower::current()->caLoopMap().areMapped(loop_id, id); }); if (it != tv->domain()->domain().end()) { loop->requireUnroll(); @@ -1260,13 +1194,12 @@ std::unordered_map indexMapReferenceTo( } // namespace -std::vector Index::getGlobalProducerStridedIndices( +std::vector Index::getGlobalProducerStridedIndices( TensorView* producer_tv, const TensorView* consumer_tv, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex"); const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); @@ -1311,9 +1244,12 @@ std::vector Index::getGlobalProducerStridedIndices( } } + kir::ForLoop* db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + // Index into the reference tensor. 
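The new double_buffer_loop argument makes the producer index run one iteration ahead of the loop index, which is the prefetch half of double buffering. A hypothetical host-side sketch of the same staging pattern, unrelated to the generated CUDA but showing the shifted read index:

#include <cassert>
#include <vector>

int main() {
  const int n = 8;
  std::vector<int> input(n);
  for (int i = 0; i < n; ++i) input[i] = 10 * i;

  int buffers[2];
  buffers[0] = input[0]; // prologue: stage the data for iteration 0

  std::vector<int> output(n);
  for (int i = 0; i < n; ++i) {
    if (i + 1 < n) {
      // Producer reads with index (i + 1), as indexMapFromTV now does when
      // loop == double_buffer_loop, and writes the other buffer half.
      buffers[(i + 1) % 2] = input[i + 1];
    }
    // Consumer uses the half staged for iteration i.
    output[i] = buffers[i % 2] + 1;
  }
  for (int i = 0; i < n; ++i) assert(output[i] == 10 * i + 1);
  return 0;
}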
Reference indexing will handle vectorized // dims where index should be set to 0 - auto ref_compute = getReferenceIndexing(loops, reference_domain); + auto ref_compute = getReferenceIndexing(loops, reference_domain, db_loop); // Forward vectorized IDs to index into producer correctly // We want p_id to be vectorized like consumer just for the indexing, then we @@ -1355,25 +1291,24 @@ std::vector Index::getGlobalProducerStridedIndices( auto root_dom = producer_tv->getMaybeRFactorDomain(); // TODO: Abstract stride logic to reuse with consumer indexing - auto zero = ir_builder.create(0); - std::vector strides(root_dom.size(), nullptr); + std::vector strides(root_dom.size(), nullptr); { int stride_i = 0; for (const auto i : c10::irange(root_dom.size())) { if (root_dom[i]->isReduction() || root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) { - strides[i] = zero; + strides[i] = GpuLower::current()->kernel()->oneVal(); continue; } std::stringstream ss; ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]"; - strides[i] = ir_builder.create(ss.str(), DataType::Int); + strides[i] = IrBuilder::create(ss.str(), DataType::Int); } } TORCH_INTERNAL_ASSERT( root_dom.size() == producer_tv->domain()->contiguity().size()); - kir::Val* cur_contig_stride = ir_builder.create(1); + Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; if (root_dom[dim]->isReduction()) { @@ -1383,14 +1318,12 @@ std::vector Index::getGlobalProducerStridedIndices( continue; } - kir::Val* root_ind = nullptr; - auto kir_root_dom = - gpu_lower->lowerValue(root_dom[dim])->as(); - if (producer_indexing.indexMap().find(kir_root_dom) != + Val* root_ind = nullptr; + if (producer_indexing.indexMap().find(root_dom[dim]) != producer_indexing.indexMap().end()) { - root_ind = producer_indexing.indexMap().at(kir_root_dom); + root_ind = producer_indexing.indexMap().at(root_dom[dim]); } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) { - root_ind = zero; + root_ind = GpuLower::current()->kernel()->zeroVal(); } TORCH_INTERNAL_ASSERT( @@ -1410,12 +1343,12 @@ std::vector Index::getGlobalProducerStridedIndices( // by extent of this dimension auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); cur_contig_stride = - ir_builder.mulExpr(cur_contig_stride, root_dim_extent); + IrBuilder::mulExpr(cur_contig_stride, root_dim_extent); } else { // If non contiguous dimension, keep local stride information, set cur // stride to local stride * local raw extent auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); - cur_contig_stride = ir_builder.mulExpr(strides[dim], root_dim_extent); + cur_contig_stride = IrBuilder::mulExpr(strides[dim], root_dim_extent); } } @@ -1423,7 +1356,8 @@ std::vector Index::getGlobalProducerStridedIndices( loops.empty() ? nullptr : loops.back()->vectorize_shift(); // Global striding - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { // If the domain is derived from a trivial reduction, no indexing // to create. 
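Global indexing multiplies each root index by a stride that is either read from the runtime T&lt;name&gt;.stride[i] value or folded from the extents of contiguous inner dimensions. A plain-integer sketch of that folding (foldStrides is hypothetical; a fully row-major layout is assumed in the example):

#include <cassert>
#include <vector>

// Fold strides from innermost to outermost, reusing the running product
// (cur_contig_stride) for contiguous dimensions.
std::vector<long> foldStrides(const std::vector<long>& extents,
                              const std::vector<long>& given_strides,
                              const std::vector<bool>& contiguity) {
  const int n = static_cast<int>(extents.size());
  std::vector<long> strides(n);
  long cur = 1; // cur_contig_stride
  for (int dim = n - 1; dim >= 0; --dim) {
    if (contiguity[dim]) {
      strides[dim] = cur;                // no runtime stride load needed
      cur = cur * extents[dim];
    } else {
      strides[dim] = given_strides[dim]; // keep T.stride[dim]
      cur = given_strides[dim] * extents[dim];
    }
  }
  return strides;
}

int main() {
  // A fully contiguous 2x3x4 tensor: folded strides become {12, 4, 1}.
  auto s = foldStrides({2, 3, 4}, {999, 999, 999}, {true, true, true});
  assert(s[0] == 12 && s[1] == 4 && s[2] == 1);
  // Flat offset for element (1, 2, 3), as in the strided_inds sum.
  long offset = 1 * s[0] + 2 * s[1] + 3 * s[2];
  assert(offset == 23);
  return 0;
}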
@@ -1434,20 +1368,17 @@ std::vector Index::getGlobalProducerStridedIndices( continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - producer_indexing.indexMap().find(kir_root_dom_i) != + producer_indexing.indexMap().find(root_dom[i]) != producer_indexing.indexMap().end(), "Couldn't find root mapping for TV", producer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - auto root_ind = producer_indexing.indexMap().at(kir_root_dom_i); + auto root_ind = producer_indexing.indexMap().at(root_dom[i]); root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv); @@ -1465,9 +1396,9 @@ std::vector Index::getGlobalProducerStridedIndices( if (root_ind->isZeroInt()) { continue; } else { - auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]); + auto strided_ind = IrBuilder::mulExpr(root_ind, strides[i]); if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { - strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift); + strided_inds[i] = IrBuilder::addExpr(strided_ind, vectorize_shift); } else { strided_inds[i] = strided_ind; } @@ -1478,12 +1409,11 @@ std::vector Index::getGlobalProducerStridedIndices( } // Producer index for either shared or local memory -std::vector Index::getNonGlobalProducerStridedIndices( +std::vector Index::getNonGlobalProducerStridedIndices( TensorView* producer_tv, const TensorView* consumer_tv, const std::vector& loops) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); @@ -1526,31 +1456,35 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } + kir::ForLoop* consumer_db_loop = + gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + // Find allocation point of producer relative to loop nests. P2C map is // required because producer was replayed as consumer, so we can't use the // regular compute at maps to line up its iter domains with the for loops. - auto alloc_point = - loop_utils::getAllocPoint(producer_tv, loops, p2c_alloc_map, true); - std::unordered_map loop_to_ind_map; + auto alloc_info = + loop_utils::getAllocInformation(producer_tv, loops, p2c_alloc_map, true); + std::unordered_map loop_to_ind_map; std::unordered_set zero_loops; - std::tie(loop_to_ind_map, zero_loops) = - indexMapFromTV(producer_tv, loops, alloc_point, false); + std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( + producer_tv, loops, alloc_info.init_for_loop, false, consumer_db_loop); - ensureStaticIndexing(producer_tv, alloc_point.first, loops, p2c_alloc_map); + ensureStaticIndexing( + producer_tv, alloc_info.init_for_loop, loops, p2c_alloc_map); // Map loop nests to indicies, zeroing out those not used due to locality of // memory - std::unordered_map ref_id_to_ind_map; + std::unordered_map ref_id_to_ind_map; // Track which domains are not used - std::unordered_set ref_zero_domains; + std::unordered_set ref_zero_domains; // Due to rfactor/initialization reference_domain may be bigger than loop nest // structure, ignore IterDomains that aren't present in the loop nest when // indexing reference. 
TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); for (const auto loop_i : c10::irange(loops.size())) { - auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i)) - ->as(); + auto ref_axis = reference_domain->axis(loop_i); ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]]; if (zero_loops.count(loops[loop_i]) > 0) { ref_zero_domains.insert(ref_axis); @@ -1677,8 +1611,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( } // Already an entry for this root domain, continue - if (index_map.find(gpu_lower->lowerValue(root_id)->as()) != - index_map.end()) { + if (index_map.find(root_id) != index_map.end()) { continue; } @@ -1690,25 +1623,23 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { if (skip_indexing.count(root_dom[i])) { continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_i) != index_map.end(), + index_map.find(root_dom[i]) != index_map.end(), "Couldn't find root mapping for TV", producer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - auto root_ind_i = index_map.at(kir_root_dom_i); + auto root_ind_i = index_map.at(root_dom[i]); root_ind_i = getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv); @@ -1729,17 +1660,14 @@ std::vector Index::getNonGlobalProducerStridedIndices( } // Compute striding for this index. - kir::Val* stride = nullptr; + Val* stride = nullptr; for (const auto j : c10::irange(i + 1, root_dom.size())) { if (skip_indexing.count(root_dom[j])) { continue; } - auto kir_root_dom_j = - gpu_lower->lowerValue(root_dom[j])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_j) != index_map.end(), + index_map.find(root_dom[j]) != index_map.end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", @@ -1747,37 +1675,49 @@ std::vector Index::getNonGlobalProducerStridedIndices( " id: ", root_dom[i]); - auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end() - ? kir_root_dom_j->extent() - : extent_map.at(kir_root_dom_j); + auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() + ? 
root_dom[j]->extent() + : extent_map.at(root_dom[j]); root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j); - if (zero_domain_map.count(kir_root_dom_j) == 0) { + if (zero_domain_map.count(root_dom[j]) == 0) { if (stride == nullptr) { stride = root_ext_j; } else { - stride = ir_builder.mulExpr(stride, root_ext_j); + stride = IrBuilder::mulExpr(stride, root_ext_j); } } } if (stride != nullptr) { - strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride); + strided_inds[i] = IrBuilder::mulExpr(root_ind_i, stride); } else { strided_inds[i] = root_ind_i; } } + if (producer_tv->isDoubleBuffered()) { + auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + producer_tv, loops, true); + if (db_loop != nullptr) { + auto db_switch_index = + IrBuilder::modExpr(db_loop->index(), IrBuilder::create(2)); + auto original_alloc_size = + gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); + auto db_strided_index = + IrBuilder::mulExpr(db_switch_index, original_alloc_size); + strided_inds.push_back(db_strided_index); + } + } return strided_inds; } -std::vector Index::getGlobalConsumerStridedIndices( +std::vector Index::getGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex"); const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); @@ -1813,26 +1753,27 @@ std::vector Index::getGlobalConsumerStridedIndices( auto root_dom = consumer_tv->getMaybeRFactorDomain(); // TODO: Abstract stride logic to reuse with producer indexing - auto zero = ir_builder.zeroVal(); - std::vector strides(root_dom.size(), zero); + std::vector strides( + root_dom.size(), GpuLower::current()->kernel()->oneVal()); { int stride_i = 0; for (const auto i : c10::irange(root_dom.size())) { if (root_dom[i]->isReduction() || root_dom[i]->getIterType() == IterType::BroadcastWithoutStride || root_dom[i]->isStride()) { - strides[i] = zero; + strides[i] = GpuLower::current()->kernel()->oneVal(); continue; } std::stringstream ss; ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]"; - strides[i] = ir_builder.create(ss.str(), DataType::Int); + strides[i] = + SimplifyingIrBuilder::create(ss.str(), DataType::Int); } } TORCH_INTERNAL_ASSERT( root_dom.size() == consumer_tv->domain()->contiguity().size()); - kir::Val* cur_contig_stride = ir_builder.oneVal(); + Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) { @@ -1842,14 +1783,12 @@ std::vector Index::getGlobalConsumerStridedIndices( continue; } - kir::Val* root_ind = nullptr; - auto kir_root_dom = - gpu_lower->lowerValue(root_dom[dim])->as(); - if (consumer_indexing.indexMap().find(kir_root_dom) != + Val* root_ind = nullptr; + if (consumer_indexing.indexMap().find(root_dom[dim]) != consumer_indexing.indexMap().end()) { - root_ind = consumer_indexing.indexMap().at(kir_root_dom); + root_ind = consumer_indexing.indexMap().at(root_dom[dim]); } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) { - root_ind = zero; + root_ind = GpuLower::current()->kernel()->zeroVal(); } TORCH_INTERNAL_ASSERT( @@ -1869,11 +1808,11 @@ std::vector Index::getGlobalConsumerStridedIndices( // by extent of this dimension auto 
root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); cur_contig_stride = - ir_builder.mulExpr(cur_contig_stride, root_dim_extent); + SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent); } else { // If non contiguous dimension, keep local stride information, set cur // stride to local stride * local raw extent - cur_contig_stride = ir_builder.mulExpr( + cur_contig_stride = SimplifyingIrBuilder::mulExpr( strides[dim], getHaloExtentOfRootAxis(root_dom[dim])); } } @@ -1882,7 +1821,8 @@ std::vector Index::getGlobalConsumerStridedIndices( loops.empty() ? nullptr : loops.back()->vectorize_shift(); // Global striding - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { // See a comment in indexing to root domains in getGlobalProducerIndex. if (root_dom[i]->isReduction() || @@ -1893,71 +1833,70 @@ std::vector Index::getGlobalConsumerStridedIndices( continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - consumer_indexing.indexMap().find(kir_root_dom_i) != + consumer_indexing.indexMap().find(root_dom[i]) != consumer_indexing.indexMap().end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - auto root_ind = consumer_indexing.indexMap().at(kir_root_dom_i); + auto root_ind = consumer_indexing.indexMap().at(root_dom[i]); - root_ind = ir_builder.addExpr( - root_ind, getGlobalConsumerOffsetWithPartialSplit(kir_root_dom_i)); + root_ind = SimplifyingIrBuilder::addExpr( + root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i])); if (root_ind->isZeroInt()) { continue; } else { - auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]); + auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]); if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { - strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift); + strided_inds[i] = + SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift); } else { strided_inds[i] = strided_ind; } } } + TORCH_INTERNAL_ASSERT( + strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); + return strided_inds; } // Consumer index for either shared or local memory -std::vector Index::getNonGlobalConsumerStridedIndices( +std::vector Index::getNonGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); auto reference_domain = reference.domain; auto reference_id_map = reference.concrete_to_id; - auto alloc_point = loop_utils::getAllocPoint(consumer_tv, loops); - std::unordered_map loop_to_ind_map; + auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops); + std::unordered_map loop_to_ind_map; std::unordered_set zero_loops; std::tie(loop_to_ind_map, zero_loops) = - indexMapFromTV(consumer_tv, loops, alloc_point, true); + indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true); - ensureStaticIndexing(consumer_tv, alloc_point.first, loops); + ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops); // Map loop nests to indicies, zeroing out those not used due to locality of // memory - std::unordered_map ref_id_to_ind_map; - 
std::unordered_set ref_zero_domains; + std::unordered_map ref_id_to_ind_map; + std::unordered_set ref_zero_domains; // Due to rfactor/initialization reference_domain may be bigger than loop nest // structure, ignore IterDomains that aren't present in the loop nest when // indexing reference. TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); for (const auto loop_i : c10::irange(loops.size())) { - auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i)) - ->as(); + auto ref_axis = reference_domain->axis(loop_i); ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]]; if (zero_loops.count(loops[loop_i]) > 0) { ref_zero_domains.insert(ref_axis); @@ -2022,7 +1961,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( // Indices should now be mapped onto IterDomains in consumer, so just grab // and use them. auto root_dom = consumer_tv->getMaybeRFactorDomain(); - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) || @@ -2030,25 +1970,22 @@ std::vector Index::getNonGlobalConsumerStridedIndices( continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_i) != index_map.end(), + index_map.find(root_dom[i]) != index_map.end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - const auto root_ind_i = index_map.at(kir_root_dom_i); + const auto root_ind_i = index_map.at(root_dom[i]); if (root_ind_i->isZeroInt()) { continue; } // Compute striding for this index. - kir::Val* stride = nullptr; + Val* stride = nullptr; for (const auto j : c10::irange(i + 1, root_dom.size())) { if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() || gpu_lower->trivialReductionInfo().isDerived(root_dom[j]) || @@ -2056,11 +1993,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( continue; } - auto kir_root_dom_j = - gpu_lower->lowerValue(root_dom[j])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_j) != index_map.end(), + index_map.find(root_dom[j]) != index_map.end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", @@ -2068,45 +2002,67 @@ std::vector Index::getNonGlobalConsumerStridedIndices( " id: ", root_dom[i]); - auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end() - ? kir_root_dom_j->extent() - : extent_map.at(kir_root_dom_j); + auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() + ? root_dom[j]->extent() + : extent_map.at(root_dom[j]); root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j); - if (zero_domain_map.count(kir_root_dom_j) == 0) { + if (zero_domain_map.count(root_dom[j]) == 0) { if (stride == nullptr) { stride = root_ext_j; } else { - stride = ir_builder.mulExpr(stride, root_ext_j); + stride = IrBuilder::mulExpr(stride, root_ext_j); } } } if (stride != nullptr) { - strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride); + strided_inds[i] = IrBuilder::mulExpr(root_ind_i, stride); } else { strided_inds[i] = root_ind_i; } } + // This check was originally done in getConsumerStridedIndices, but + // the number of strided index values depends on the loop where the + // consumer tensor is located. 
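For shared and local tensors there are no runtime strides; as in the stride loops above, each root index is scaled by the product of the inner, non-zeroed, halo-extended extents. A hypothetical plain-integer version of that accumulation:

#include <cassert>
#include <vector>

// offset = sum_i index[i] * prod_{j > i, !zero[j]} extent[j]
long localOffset(const std::vector<long>& index,
                 const std::vector<long>& extent,
                 const std::vector<bool>& zero_domain) {
  const int n = static_cast<int>(index.size());
  long offset = 0;
  for (int i = 0; i < n; ++i) {
    long stride = 1;
    for (int j = i + 1; j < n; ++j) {
      if (!zero_domain[j]) stride *= extent[j]; // IrBuilder::mulExpr chain
    }
    offset += index[i] * stride;
  }
  return offset;
}

int main() {
  // 4x8 buffer, nothing zeroed: element (2, 5) sits at offset 21.
  assert(localOffset({2, 5}, {4, 8}, {false, false}) == 21);
  // If the inner dimension is bound to a thread index of a per-thread
  // (local) buffer, it is zeroed and no longer scales the outer dimension.
  assert(localOffset({2, 0}, {4, 8}, {false, true}) == 2);
  return 0;
}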
If it's double buffered and not in + // the prologue loop, strided_inds ends up having one more + // index, so it's just much simpler to check here before adding the + // additional index for double buffering. + TORCH_INTERNAL_ASSERT( + strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); + + if (consumer_tv->isDoubleBuffered()) { + auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + if (db_loop != nullptr) { + auto db_switch_index = IrBuilder::subExpr( + gpu_lower->kernel()->oneVal(), + IrBuilder::modExpr(db_loop->index(), IrBuilder::create(2))); + auto original_alloc_size = + gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv); + auto db_strided_index = + IrBuilder::mulExpr(db_switch_index, original_alloc_size); + strided_inds.push_back(db_strided_index); + } + } + return strided_inds; } -std::vector Index::getProducerStridedIndices( +std::vector Index::getProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices"); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - if (producer->domain()->noReductions().size() == 0) { - return std::vector( - producer->getMaybeRFactorDomain().size(), ir_builder.zeroVal()); + return std::vector( + producer->getMaybeRFactorDomain().size(), + GpuLower::current()->kernel()->zeroVal()); } - std::vector strided_indices; + std::vector strided_indices; if (producer->getMemoryType() == MemoryType::Global) { strided_indices = getGlobalProducerStridedIndices(producer, consumer, loops); @@ -2116,7 +2072,9 @@ std::vector Index::getProducerStridedIndices( } TORCH_INTERNAL_ASSERT( - strided_indices.size() == producer->getMaybeRFactorDomain().size()); + strided_indices.size() == + producer->getMaybeRFactorDomain().size() + + (producer->isDoubleBuffered() ? 
1 : 0)); return strided_indices; } @@ -2126,35 +2084,27 @@ kir::TensorIndex* Index::getProducerIndex( TensorView* producer, const TensorView* consumer, const std::vector& loops) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto strided_indices = getProducerStridedIndices(producer, consumer, loops); - return ir_builder.create(producer, strided_indices); + return IrBuilder::create(producer, strided_indices); } -std::vector Index::getConsumerStridedIndices( +std::vector Index::getConsumerStridedIndices( const TensorView* consumer, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices"); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - if (consumer->domain()->noReductions().size() == 0) { - return std::vector( - consumer->getMaybeRFactorDomain().size(), ir_builder.zeroVal()); + return std::vector( + consumer->getMaybeRFactorDomain().size(), + GpuLower::current()->kernel()->zeroVal()); } - std::vector strided_indices; + std::vector strided_indices; if (consumer->getMemoryType() == MemoryType::Global) { strided_indices = getGlobalConsumerStridedIndices(consumer, loops); } else { strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops); } - TORCH_INTERNAL_ASSERT( - strided_indices.size() == consumer->getMaybeRFactorDomain().size()); - return strided_indices; } @@ -2162,11 +2112,8 @@ std::vector Index::getConsumerStridedIndices( kir::TensorIndex* Index::getConsumerIndex( const TensorView* consumer, const std::vector& loops) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto strided_indices = getConsumerStridedIndices(consumer, loops); - return ir_builder.create(consumer, strided_indices); + return IrBuilder::create(consumer, strided_indices); } namespace { @@ -2184,37 +2131,19 @@ struct PredicateDomainInfo { bool is_non_divisible_split = false; }; -// Find iteration domains in the history of reference comprised only of -// merge operations. Only return iteration domains that are subsequently fed -// into a split, or are in the provided domain. In other words, we don't want to -// return every IterDomain that's contiguous, just the one closest to the -// leaves. Predicates are not associated with physical memory so we can treat -// all of them as contiguous merges. +// Find iteration domains in the history of a consumer to predicate comprised +// only of merge operations. Only return iteration domains that are subsequently +// fed into a split, or are in the provided domain. In other words, we don't +// want to return every IterDomain that's contiguous, just the one closest to +// the leaves. Predicates are not associated with physical memory so we can +// treat all of them as contiguous merges. std::vector getPredicateContigIds( - const ReferenceTensor& reference, - TensorView* consumer_tv, - const std::unordered_map& ref_2_consumer) { + TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - std::vector reference_predicated_root_domain; - for (const auto consumer_root : consumer_tv->getRootDomain()) { - if (consumer_root->isBroadcast()) { - continue; - } - auto consumer_root_concrete = - gpu_lower->caIndexMap().getConcreteMappedID(consumer_root); - auto it = reference.concrete_to_id.find(consumer_root_concrete); - // When initializing a reduction buffer, the reduction axis - // doesn't have a loop, so the reference tensor doesn't have a - // mapped domain. 
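The two double-buffering hunks above append one extra index term to a double-buffered tensor: the producer side selects half (index % 2) of the doubled allocation, and the consumer side selects the opposite half (1 - index % 2). A small sketch of why that ping-pongs correctly (the helper names are hypothetical):

#include <cassert>

long writeHalfOffset(long loop_index, long original_alloc_size) {
  return (loop_index % 2) * original_alloc_size;     // producer hunk above
}
long readHalfOffset(long loop_index, long original_alloc_size) {
  return (1 - loop_index % 2) * original_alloc_size; // consumer hunk above
}

int main() {
  const long alloc = 128;
  for (long i = 0; i < 4; ++i) {
    // Within one iteration the two sides never touch the same half, and the
    // half read at iteration i is the one written at iteration i - 1.
    assert(writeHalfOffset(i, alloc) != readHalfOffset(i, alloc));
    if (i > 0) {
      assert(readHalfOffset(i, alloc) == writeHalfOffset(i - 1, alloc));
    }
  }
  return 0;
}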
The reduction axis can be safely ignored. - if (it == reference.concrete_to_id.end()) { - continue; - } - auto reference_root = it->second; - reference_predicated_root_domain.emplace_back(reference_root); - } + const auto& consumer_root_domain = consumer_tv->getRootDomain(); - std::vector contiguous_ids = reference_predicated_root_domain; + std::vector contiguous_ids = consumer_root_domain; if (contiguous_ids.empty()) { return std::vector(); @@ -2227,20 +2156,24 @@ std::vector getPredicateContigIds( // about halo to do correct predication, so they must be excluded. std::unordered_set excluded_ids; - for (auto reference_predicated_id : reference_predicated_root_domain) { - if (GpuLower::current() - ->haloInfo() - .getRootAxisInfo(reference_predicated_id) - .hasHalo()) { + for (auto consumer_root_id : consumer_root_domain) { + if (gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id).hasHalo()) { + excluded_ids.insert(consumer_root_id); continue; } - auto it = ref_2_consumer.find(reference_predicated_id); - if (it == ref_2_consumer.end()) { + if (consumer_root_id->maybePartial()) { + excluded_ids.insert(consumer_root_id); continue; } - auto consumer_root_id = it->second; - if (consumer_root_id->maybePartial()) { - excluded_ids.insert(reference_predicated_id); + // When consumer_root_id is a broadcast domain, do not allow contig + // predication as the merged output is not mapped with the + // reference unless the concrete domain is also a broadcast + // domain. + if (consumer_root_id->isBroadcast() && + !gpu_lower->caLoopMap() + .getConcreteMappedID(consumer_root_id) + ->isBroadcast()) { + excluded_ids.insert(consumer_root_id); continue; } // Shifted or gathered axes need to be predicated at the root domain @@ -2252,15 +2185,16 @@ std::vector getPredicateContigIds( auto consumer_root_pos = consumer_tv->domain()->rootPosOf(consumer_root_id); if ((shift_expr && shift_expr->offset(consumer_root_pos) != 0) || (gather_expr && consumer_root_pos < gather_expr->windowShape().size() && - !gather_expr->windowShape().at(consumer_root_pos)->isOneInt())) { - excluded_ids.insert(reference_predicated_id); + gather_expr->windowShape().at(consumer_root_pos) != 1)) { + excluded_ids.insert(consumer_root_id); } } // Run through iteration domain history - auto exprs = ExprSort::getExprs( + auto exprs = StmtSort::getExprs( consumer_tv->fusion(), - {reference.domain->domain().begin(), reference.domain->domain().end()}); + {consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end()}); for (auto expr : exprs) { // If not a merge, output is not contiguous @@ -2296,8 +2230,7 @@ std::vector getPredicateContigIds( // reference_predicated_root_domain. 
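Predicating a domain produced only by merges is equivalent to predicating each of its root domains individually when they are treated as contiguous, which is what lets getPredicateContigIds replace several root predicates with one check against the merged extent. A brute-force illustration of that equivalence:

#include <cassert>

int main() {
  const long ext_outer = 7, ext_inner = 5;
  const long merged_extent = ext_outer * ext_inner;

  // Walk merged indices past the valid range and check that the per-root
  // predicate and the single merged predicate always agree.
  for (long merged_idx = 0; merged_idx < merged_extent + 10; ++merged_idx) {
    long outer = merged_idx / ext_inner;
    long inner = merged_idx % ext_inner;
    bool per_root_pred = (outer < ext_outer) && (inner < ext_inner);
    bool merged_pred = merged_idx < merged_extent;
    assert(per_root_pred == merged_pred);
  }
  return 0;
}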
auto contig_root_vals = IterVisitor::getInputsTo( {contig_id}, - {reference_predicated_root_domain.begin(), - reference_predicated_root_domain.end()}); + {consumer_root_domain.begin(), consumer_root_domain.end()}); auto contig_root_ids = ir_utils::filterByType(contig_root_vals); PredicateDomainInfo contig_id_info; contig_id_info.id = contig_id; @@ -2312,8 +2245,7 @@ IterDomain* getMappedReferenceDomain( IterDomain* id, const ReferenceTensor& reference) { // Partially overlaps with getPredicateContigIds() - const auto gpu_lower = GpuLower::current(); - auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(id); + auto concrete_id = GpuLower::current()->caIndexMap().getConcreteMappedID(id); auto it = reference.concrete_to_id.find(concrete_id); if (it == reference.concrete_to_id.end()) { return nullptr; @@ -2321,9 +2253,8 @@ IterDomain* getMappedReferenceDomain( return it->second; } -std::vector getNonDivisibleReferenceDomainsToPredicate( - TensorView* consumer_tv, - const ReferenceTensor& reference) { +std::vector getNonDivisibleConsumerDomainsToPredicate( + TensorView* consumer_tv) { const auto& non_divisible_split_info = GpuLower::current()->nonDivisibleSplitInfo(); @@ -2337,11 +2268,7 @@ std::vector getNonDivisibleReferenceDomainsToPredicate( const auto& splits_to_predicate = it->second; for (auto split : splits_to_predicate) { - auto ref_id = getMappedReferenceDomain(split->in(), reference); - if (ref_id == nullptr) { - continue; - } - PredicateDomainInfo info{ref_id, {ref_id}, true}; + PredicateDomainInfo info{split->in(), {split->in()}, true}; pred_info_vec.emplace_back(info); } @@ -2352,9 +2279,8 @@ bool needsPadding(TensorView* tv) { auto shift_expr = dynamic_cast(tv->definition()); auto gather_expr = dynamic_cast(tv->definition()); - // Padding is only necessary for padded shift and - // gather - return (shift_expr != nullptr && shift_expr->pad()) || gather_expr != nullptr; + return (shift_expr != nullptr && shift_expr->hasPadding()) || + (gather_expr != nullptr && gather_expr->hasPadding()); } // Get an additional offset of a stop index when building a predicate @@ -2364,11 +2290,10 @@ bool needsPadding(TensorView* tv) { // compared with each other by just looking at the additional offsets. // // consumer_root_id: the domain for which a stop predicate is being built. -kir::Val* getUnswitchStopOffset( +int getUnswitchStopOffset( IterDomain* consumer_root_id, TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id); @@ -2376,7 +2301,7 @@ kir::Val* getUnswitchStopOffset( // If the consumer root domain to predicate does not have halo, no // adjustment is required. if (!halo_info.hasHalo()) { - return ir_builder.zeroVal(); + return 0; } // Find if this contig_id is used in the unswitched domains @@ -2400,22 +2325,14 @@ kir::Val* getUnswitchStopOffset( })) { return halo_info.width(); } else { - return ir_builder.zeroVal(); + return 0; } } -// Get offsets for the start and stop predicates. Similar to the -// gather case, but it's a little simpler as it does not (yet) -// dynamic shifting. 
-void adjustStartAndStopOffsetsForShift( - std::vector& start_offsets, - std::vector& stop_offsets, +std::pair getStartAndStopOffsetsForShift( TensorView* consumer_tv, IterDomain* consumer_id, bool padding_predicate) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - TORCH_INTERNAL_ASSERT(consumer_id != nullptr); auto shift_expr = dynamic_cast(consumer_tv->definition()); @@ -2423,105 +2340,124 @@ void adjustStartAndStopOffsetsForShift( // Adjustment is not necessary if not shift. // Even so, padding predicate does not need any adjustment. if (shift_expr == nullptr || padding_predicate) { - return; + return { + GpuLower::current()->kernel()->zeroVal(), + GpuLower::current()->kernel()->zeroVal()}; } const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id); - // Assume this adjustment is done first, so start and stop offsets - // just contain zeroVal. - TORCH_INTERNAL_ASSERT( - start_offsets.size() == 1 && start_offsets[0]->isZeroInt() && - stop_offsets.size() == 1 && stop_offsets[0]->isZeroInt()); - start_offsets.clear(); - stop_offsets.clear(); - - // The consumer offset is zero. - auto consumer_offset = 0; - // The producer offset is based off the consumer offset. - auto producer_offset = 0; - - // When the shift operation is not padded, the start and stop positions of the - // consumer axis, i.e., consumer_id->start and - // consumer_id->stop_ofset, are adjusted accordingly, which includes - // the effect of the shift offset, so using the consumer offset is - // sufficient as the only predicate is sufficient. - - if (shift_expr->pad()) { - // Positive shift offset means shifting the input tensor to the - // positive direction, so the producer offset becomes negative. - auto shift_offset = shift_expr->offset(root_axis_pos); - producer_offset = -shift_offset; - } - - // Since shift doesn't allow dynamic offsets, we can statically - // choose more restrictive offsets between the producer and consumer - // offsets. The start predicate uses greater-than, so using the - // smaller offset is sufficient. Similarly, for the stop predicate, - // using the larger offset is sufficient. - auto start_offset = std::min(consumer_offset, producer_offset); - auto stop_offset = std::max(consumer_offset, producer_offset); - - start_offsets.push_back(ir_builder.create(start_offset)); - stop_offsets.push_back(ir_builder.create(stop_offset)); + // The first or last N elements, where N is the padding width, + // correspond to the padding predicate. + + const auto shift_offset = shift_expr->offset(root_axis_pos); + const auto pad_width = shift_expr->padWidth().at(root_axis_pos); + + int start_offset = 0; + int stop_offset = 0; + + if (shift_offset > 0) { + start_offset = -pad_width; + } else if (shift_offset < 0) { + stop_offset = pad_width; + } + + return { + IrBuilder::create(start_offset), + IrBuilder::create(stop_offset)}; } -// Get offsets for the start and stop predicates. There can be two -// offsets because the shift offset is determined by a loop index. 
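With padded shift, the pad_width outputs at one end are produced by the padding predicate, so the main predicate is tightened on that side only: a positive shift offsets the start check by -pad_width and a negative shift offsets the stop check by +pad_width, under the index + offset >= 0 / index + offset < extent convention used further down. A hypothetical sketch of that selection:

#include <cassert>
#include <utility>

// Returns {start_offset, stop_offset} added to the index before the
// `index >= 0` and `index < extent` checks, as in the hunk above.
std::pair<int, int> shiftPredicateOffsets(int shift_offset, int pad_width) {
  int start_offset = 0;
  int stop_offset = 0;
  if (shift_offset > 0) {
    start_offset = -pad_width; // first pad_width outputs come from padding
  } else if (shift_offset < 0) {
    stop_offset = pad_width;   // last pad_width outputs come from padding
  }
  return {start_offset, stop_offset};
}

int main() {
  // out[i] = in[i - 2] with pad width 2: positions 0 and 1 are handled by
  // the padding path, so the main start check tightens to index >= 2.
  auto [start, stop] = shiftPredicateOffsets(2, 2);
  assert(start == -2 && stop == 0);
  auto [start2, stop2] = shiftPredicateOffsets(-3, 3);
  assert(start2 == 0 && stop2 == 3);
  return 0;
}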
-void adjustStartAndStopOffsetsForGather( - std::vector& start_offsets, - std::vector& stop_offsets, +std::pair getStartAndStopOffsetsForGather( TensorView* consumer_tv, IterDomain* consumer_id, - const ReferenceTensor& reference, - const std::unordered_map& ref_start_index_map, - const std::unordered_map& ref_stop_index_map, + const std::unordered_map& ref_start_index_map, + const std::unordered_map& ref_stop_index_map, bool padding_predicate) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - TORCH_INTERNAL_ASSERT(consumer_id != nullptr); // Adjustment is not necessary if not gather. Even so, padding // predicate does not need any adjustment. if (!consumer_tv->definition()->isA() || padding_predicate) { - return; + return { + GpuLower::current()->kernel()->zeroVal(), + GpuLower::current()->kernel()->zeroVal()}; } const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id); - // Assume this adjustment is done first, so start and stop offsets - // just contain zeroVal. - TORCH_INTERNAL_ASSERT( - start_offsets.size() == 1 && start_offsets[0]->isZeroInt() && - stop_offsets.size() == 1 && stop_offsets[0]->isZeroInt()); - start_offsets.clear(); - stop_offsets.clear(); - auto producer_start_offset = getProducerOffsetWithGather( - root_axis_pos, - consumer_tv, - reference.concrete_to_id, - ref_start_index_map); + root_axis_pos, consumer_tv, ref_start_index_map); auto producer_stop_offset = getProducerOffsetWithGather( - root_axis_pos, consumer_tv, reference.concrete_to_id, ref_stop_index_map); + root_axis_pos, consumer_tv, ref_stop_index_map); - // The producer and consumer accesses must be predicated as it is - // not statically determined which is more restrictive. + auto consumer_start_offset = GpuLower::current()->kernel()->zeroVal(); + auto consumer_stop_offset = GpuLower::current()->kernel()->zeroVal(); - // Consumer offsets are just zero. - start_offsets.push_back(ir_builder.zeroVal()); - stop_offsets.push_back(ir_builder.zeroVal()); + if (producer_start_offset->isZeroInt() && producer_stop_offset->isZeroInt()) { + return {consumer_start_offset, consumer_stop_offset}; + } + + Val* start_offset = nullptr; + Val* stop_offset = nullptr; + + // In the normal case, take the minimum of the start and the + // maximum of the stop offsets. If there's no padding, the producer + // offset must be always larger than the consumer + // offset. So, the consumer and produce offsets can be always used + // for the start and stop offsets, respectively. + const auto pad_left = + consumer_tv->definition()->as()->padWidth()[root_axis_pos][0]; + const auto pad_right = + consumer_tv->definition()->as()->padWidth()[root_axis_pos][1]; + const auto window_size = + consumer_tv->definition()->as()->windowShape()[root_axis_pos]; - // Adds producer offsets if they are not zero. 
- if (!producer_start_offset->isZeroInt()) { - start_offsets.push_back(producer_start_offset); + // consumer index: index + // producer index: index + window_index - pad_left + // + // consumer extent: ext + // producer extent: ext + window_size - 1 - pad_left - pad_right + // + // consumer stop pred: index < ext + // producer stop pred: index + window_index - pad_left < ext + window_size - 1 + // - pad_left - pad_right + // -> index + window_index - pad_left - (window_size - 1 - + // pad_left - pad_right) < ext + // -> index + window_index - (window_size - 1 - pad_right) < + // ext + // + // consumer start pred: index >= 0 + // producer start pred: index + window_index - pad_left >= 0 + + const auto producer_ext_adj = window_size - 1 - pad_left - pad_right; + producer_stop_offset = SimplifyingIrBuilder::subExpr( + producer_stop_offset, + SimplifyingIrBuilder::create(producer_ext_adj)); + + // As commented above, when pad_left is zero, the consumer predicate + // is always more restrictive than the producer predicate. + if (pad_left == 0) { + start_offset = consumer_start_offset; + } else { + start_offset = SimplifyingIrBuilder::minExpr( + consumer_start_offset, producer_start_offset); } - if (!producer_stop_offset->isZeroInt()) { - stop_offsets.push_back(producer_stop_offset); + // As commented above, when pad_right is zero, the consumer + // predicate is always more restrictive than the producer + // predicate. + if (pad_right == 0) { + stop_offset = consumer_stop_offset; + } else { + stop_offset = SimplifyingIrBuilder::maxExpr( + consumer_stop_offset, producer_stop_offset); } + + TORCH_INTERNAL_ASSERT(start_offset != nullptr); + TORCH_INTERNAL_ASSERT(stop_offset != nullptr); + + return {start_offset, stop_offset}; } // Get the start and stop limit offsets that define the valid range to @@ -2530,18 +2466,16 @@ void adjustStartAndStopOffsetsForGather( // stop that's different from extent. Also, when IterDomain has halo, // the actual offsets of the logical start and stop positions are // shifted. -std::pair getStartAndStopLimitOffsets( +std::pair getStartAndStopLimitOffsets( IterDomain* consumer_id, bool padding_predicate, bool non_divisible_pred) { const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); TORCH_INTERNAL_ASSERT(consumer_id != nullptr); - kir::Val* start_limit = gpu_lower->lowerValue(consumer_id->start()); - kir::Val* stop_limit = - ir_builder.negExpr(gpu_lower->lowerValue(consumer_id->stopOffset())); + Val* start_limit = consumer_id->start(); + Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset()); if (!non_divisible_pred) { AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_id); @@ -2554,12 +2488,14 @@ std::pair getStartAndStopLimitOffsets( // [0, left halo)[start_limit, stop_limit)[0, right halo) // if (!padding_predicate) { - start_limit = ir_builder.addExpr(start_limit, halo_info.width(0)); - stop_limit = ir_builder.addExpr(stop_limit, halo_info.width(0)); + start_limit = + SimplifyingIrBuilder::addExpr(start_limit, halo_info.width(0)); + stop_limit = + SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width(0)); } else { // In case of the padding predicate, the whole range, including both left // and right halo regions, is computed. 
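The min/max combination above folds the consumer bound and the producer bound (producer index = index + window_index - pad_left, producer extent = extent + window_size - 1 - pad_left - pad_right, as in the comment) into a single pair of offsets. A brute-force check of that equivalence for one assumed gather configuration:

#include <algorithm>
#include <cassert>

int main() {
  // Hypothetical configuration: window of 3, padded by 1 on the left and
  // 0 on the right, consumer extent 8.
  const int window_size = 3, pad_left = 1, pad_right = 0, ext = 8;
  const int producer_ext = ext + window_size - 1 - pad_left - pad_right;
  const int producer_ext_adj = window_size - 1 - pad_left - pad_right;

  for (int i = -2; i < ext + 2; ++i) {
    for (int w = 0; w < window_size; ++w) {
      // Offsets combined as in the hunk above.
      int start_offset = std::min(0, w - pad_left);
      int stop_offset = std::max(0, (w - pad_left) - producer_ext_adj);
      bool combined_pred = (i + start_offset >= 0) && (i + stop_offset < ext);

      // What actually has to hold: consumer and producer both in bounds.
      int producer_index = i + w - pad_left;
      bool reference_pred = (i >= 0) && (i < ext) &&
          (producer_index >= 0) && (producer_index < producer_ext);

      assert(combined_pred == reference_pred);
    }
  }
  return 0;
}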
- stop_limit = ir_builder.addExpr(stop_limit, halo_info.width()); + stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width()); } } else { // For non-divisible predicates, the index must be predicated such @@ -2568,28 +2504,26 @@ std::pair getStartAndStopLimitOffsets( // isn't a root domain. if (gpu_lower->haloInfo().hasHaloWidth(consumer_id)) { auto halo = gpu_lower->haloInfo().getHaloWidth(consumer_id); - stop_limit = ir_builder.addExpr(stop_limit, halo); + stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo); } } return {start_limit, stop_limit}; } -// Return an index map for a predicate reference tensor. Two different +// Return an IndexCompute for a predicate reference tensor. Two different // maps are used when generating predicates for unswitched expressions // as start and stop conditions need to use different loop-to-index // mappings. -std::unordered_map getPredicateReferenceIndexing( +auto getPredicateReferenceIndexing( const std::vector& loops, const ReferenceTensor& reference, kir::ForLoop* unswitch_or_vec_loop, + IterDomain* double_buffer_axis, bool start) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - auto reference_domain = reference.domain; - std::unordered_map loop_to_ind_map; + std::unordered_map loop_to_ind_map; std::transform( loops.begin(), @@ -2606,7 +2540,7 @@ std::unordered_map getPredicateReferenceIndexing( // vectorized loop should be like this. bool vectorized_pred = - unswitch_or_vec_loop->iter_domain()->parallelType() == + unswitch_or_vec_loop->iter_domain()->getParallelType() == ParallelType::Vectorize; TORCH_INTERNAL_ASSERT( @@ -2614,12 +2548,11 @@ std::unordered_map getPredicateReferenceIndexing( "Invalid reference generated."); bool within_unswitch = false; - const auto one = ir_builder.oneVal(); for (const auto loop_i : c10::irange(loops.size())) { auto loop = loops[loop_i]; auto loop_id = loop->iter_domain(); - auto loop_pt = loop_id->parallelType(); + auto loop_pt = loop_id->getParallelType(); auto ref_id = reference_domain->axis(loop_i); if (loop == unswitch_or_vec_loop) { @@ -2668,20 +2601,21 @@ std::unordered_map getPredicateReferenceIndexing( if (loop->stop() == loop_id->extent()) { loop_to_ind_map[loop] = loop->start(); } else if (start) { - loop_to_ind_map[loop] = ir_builder.zeroVal(); + loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); } else { // Note that the parallel dimension is used rather than // loop-stop(). See the above comment. - loop_to_ind_map[loop] = ir_builder.subExpr( - gpu_lower->parallelDimensionMap().get(loop_pt), - ir_builder.create(1)); + loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( + GpuLower::current()->parallelDimensionMap().get(loop_pt), + GpuLower::current()->kernel()->zeroVal()); } } else if (start) { - loop_to_ind_map[loop] = ir_builder.zeroVal(); + loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); } else { // Similar to the above, loop_id()->extent() is // used here instead of loop->stop(). See the above comment. 
- loop_to_ind_map[loop] = ir_builder.subExpr(loop_id->extent(), one); + loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( + loop_id->extent(), GpuLower::current()->kernel()->oneVal()); } } @@ -2693,9 +2627,27 @@ std::unordered_map getPredicateReferenceIndexing( } } + if (double_buffer_axis != nullptr) { + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + double_buffer_axis, loops, true); + if (db_loop != nullptr) { + auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop); + TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end()); + auto cur_index = loop_to_ind_map_it->second; + // if cur_index is not the same as the index of db_loop, it must + // be true that that index has been modified to support + // unswitch. In that case, it is not necessary to move ahead the + // index for double buffering. + if (cur_index == db_loop->index()) { + loop_to_ind_map[db_loop] = IrBuilder::addExpr( + cur_index, GpuLower::current()->kernel()->oneVal()); + } + } + } + // Add magic zero to a loop pretty far inside in indexing - kir::IterDomain* magic_zero_loop = nullptr; - std::unordered_map ref_id_to_ind_map; + IterDomain* magic_zero_loop = nullptr; + std::unordered_map ref_id_to_ind_map; // Due to rfactor/initialization reference_domain may be bigger than loop nest // structure TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); @@ -2703,19 +2655,19 @@ std::unordered_map getPredicateReferenceIndexing( auto loop = loops[loop_i]; auto ind = loop_to_ind_map[loops[loop_i]]; auto ref_axis = reference_domain->axis(loop_i); - auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as(); if (Index::protectWithMagicZero(loop, ref_axis, ind)) { - magic_zero_loop = kir_ref_axis; + magic_zero_loop = ref_axis; } - ref_id_to_ind_map[kir_ref_axis] = loop_to_ind_map[loop]; + ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loop]; } if (ref_id_to_ind_map.count(magic_zero_loop)) { auto& ind = ref_id_to_ind_map[magic_zero_loop]; if (!ind->isConstScalar()) { - ind = ir_builder.addExpr(ind, ir_builder.magicZeroVal()); + ind = SimplifyingIrBuilder::addExpr( + ind, GpuLower::current()->kernel()->magicZeroVal()); } } @@ -2729,7 +2681,7 @@ std::unordered_map getPredicateReferenceIndexing( ref_self_map.insert({id, id}); }); - std::unordered_map reference_halo_extent_map = + std::unordered_map reference_halo_extent_map = getReferenceHaloExtentMap(reference, ref_self_map); // Index into the reference tensor @@ -2741,64 +2693,55 @@ std::unordered_map getPredicateReferenceIndexing( {}, reference_halo_extent_map); - return index_compute.indexMap(); + return index_compute; } // Get the offsets for the start and stop predicates. The offsets // are to be added to the index. -std::pair, std::vector> getStartAndStopOffsets( +std::pair getStartAndStopOffsets( IterDomain* consumer_id, TensorView* consumer_tv, const ReferenceTensor& reference, - const std::unordered_map& ref_start_index_map, - const std::unordered_map& ref_stop_index_map, + const std::unordered_map& consumer_start_index_map, + const std::unordered_map& consumer_stop_index_map, bool padding_predicate, bool unswitch, bool non_divisible_pred) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - // By default, the offsets for the start and stop predicates are - // just zero. - std::vector start_offsets{ir_builder.zeroVal()}; - std::vector stop_offsets{ir_builder.zeroVal()}; - - if (consumer_id == nullptr) { - return {start_offsets, stop_offsets}; + // just zero. 
All halo-related adjustments are done at root domains, + // so consumer_id is not a root domain, no adjustment is required. + if (consumer_id->definition() != nullptr && !non_divisible_pred) { + return { + GpuLower::current()->kernel()->zeroVal(), + GpuLower::current()->kernel()->zeroVal()}; } auto consumer_def = consumer_tv->definition(); + Val* start_offset = GpuLower::current()->kernel()->zeroVal(); + Val* stop_offset = GpuLower::current()->kernel()->zeroVal(); + // These adjustments are not required when predicating non-divisible splits if (!non_divisible_pred) { if (consumer_def->isA()) { - adjustStartAndStopOffsetsForShift( - start_offsets, - stop_offsets, - consumer_tv, - consumer_id, - padding_predicate); + std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForShift( + consumer_tv, consumer_id, padding_predicate); } else if (consumer_def->isA()) { - adjustStartAndStopOffsetsForGather( - start_offsets, - stop_offsets, + std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForGather( consumer_tv, consumer_id, - reference, - ref_start_index_map, - ref_stop_index_map, + consumer_start_index_map, + consumer_stop_index_map, padding_predicate); } // Adjustment for partial split - auto partial_split_offset = getGlobalConsumerOffsetWithPartialSplit( - gpu_lower->lowerValue(consumer_id)->as()); - for (auto& start_offset : start_offsets) { - start_offset = ir_builder.addExpr(start_offset, partial_split_offset); - } - for (auto& stop_offset : stop_offsets) { - stop_offset = ir_builder.addExpr(stop_offset, partial_split_offset); - } + auto partial_split_offset = + getGlobalConsumerOffsetWithPartialSplit(consumer_id); + start_offset = + SimplifyingIrBuilder::addExpr(start_offset, partial_split_offset); + stop_offset = + SimplifyingIrBuilder::addExpr(stop_offset, partial_split_offset); // If generating a predicate for unswitch, adjust the stop offset to // accommodate the addition of halo to the loop stop. See the @@ -2808,9 +2751,8 @@ std::pair, std::vector> getStartAndStopOffsets !padding_predicate, "Unswitch should not use the padding predicate"); auto stop_unswitch_offset = getUnswitchStopOffset(consumer_id, consumer_tv); - for (auto& stop_offset : stop_offsets) { - stop_offset = ir_builder.addExpr(stop_offset, stop_unswitch_offset); - } + stop_offset = + SimplifyingIrBuilder::addExpr(stop_offset, stop_unswitch_offset); } } @@ -2830,39 +2772,49 @@ std::pair, std::vector> getStartAndStopOffsets // index + (start_offset - start_limit) >= 0 // index + (stop_offset - stop_limit) < extent - for (auto& start_offset : start_offsets) { - start_offset = ir_builder.subExpr(start_offset, limits.first); - } - for (auto& stop_offset : stop_offsets) { - stop_offset = ir_builder.subExpr(stop_offset, limits.second); - } + start_offset = SimplifyingIrBuilder::subExpr(start_offset, limits.first); + stop_offset = SimplifyingIrBuilder::subExpr(stop_offset, limits.second); - return {start_offsets, stop_offsets}; + return {start_offset, stop_offset}; } -bool canOmitStartPredicate(kir::Val* start_offset) { +// A partial value of a start offset is returned if determined to be +// safe. Nullptr is returned if it can be omitted completely. +Val* simplifyStartOffset(Val* start_offset) { // Start predicate can be omitted when start_offset >= 0. 
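// Editor's sketch: a small worked check, on plain integers, of the algebraic
// fold performed just above. Subtracting the limits from the offsets leaves the
// predicates in the two forms quoted in the comment, index + start_offset >= 0
// and index + stop_offset < extent. All concrete values are made up for
// illustration; the initial offsets are taken as zero, their default.
#include <cassert>

void offsetLimitFoldExample() {
  const long start_limit = 1;
  const long stop_limit = 2;
  const long extent = 8;
  long start_offset = 0;
  long stop_offset = 0;

  // The fold done above with SimplifyingIrBuilder::subExpr.
  start_offset -= start_limit; // now -1
  stop_offset -= stop_limit; // now -2

  for (long index = -1; index <= extent + stop_limit; ++index) {
    // Equivalent by moving the limit to the other side of the inequality.
    assert((index + start_offset >= 0) == (index >= start_limit));
    assert((index + stop_offset < extent) == (index < extent + stop_limit));
  }
}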
- auto offset_val = start_offset->as()->value(); - return offset_val.has_value() && offset_val.value() >= 0; + auto offset_val = start_offset->as()->value(); + if (offset_val.has_value() && offset_val.value() >= 0) { + return nullptr; + } + + // start_offset may look like min(0, window_index - pad). Then, can + // remove min and leave the rhs only. + auto def = dynamic_cast(start_offset->definition()); + if (def != nullptr && def->getBinaryOpType() == BinaryOpType::Min && + def->lhs()->isZeroInt()) { + return def->rhs(); + } + + return start_offset; } bool canOmitStopPredicate( - kir::Val* stop_index, - kir::Val* stop_offset, - kir::IterDomain* kir_contig_id) { + Val* stop_index, + Val* stop_offset, + IterDomain* contig_id) { bool index_simple = stop_index->definition() == nullptr; // The definition may be just adding the magic zero, which can be // effectively considered "simple" if (!index_simple && isProtectedWithMagicZero(stop_index)) { // Make sure the lhs of stop_index is simple. - auto lhs = stop_index->definition()->as()->lhs(); + auto lhs = stop_index->definition()->as()->lhs(); if (lhs->definition() == nullptr) { index_simple = true; } } // Omit only when both the index and extent are "simple". - if (!(index_simple && kir_contig_id->extent()->definition() == nullptr)) { + if (!(index_simple && contig_id->extent()->definition() == nullptr)) { return false; } @@ -2873,33 +2825,32 @@ bool canOmitStopPredicate( // omitted if extent + halo + stop_offset < extent, i.e., halo + // stop_offset <= 0. - auto stop_offset_val = stop_offset->as()->value(); + auto stop_offset_val = stop_offset->as()->value(); - auto halo_ext = - gpu_lower->haloInfo().getRootAxisInfo(kir_contig_id).width()->value(); + auto halo_ext = gpu_lower->haloInfo().getRootAxisInfo(contig_id).width(); // If they are not compile-time constant, can't prove the // condition. - if (!stop_offset_val.has_value() || !halo_ext.has_value()) { + if (!stop_offset_val.has_value()) { return false; } - if (halo_ext.value() + stop_offset_val.value() > 0) { + if (halo_ext + stop_offset_val.value() > 0) { return false; } // When the domain is parallelized, the parallel dimension must be // exact. Otherwise, there would be extra threads/blocks that need // to be predicated out. - if (isParallelTypeThread(kir_contig_id->parallelType())) { + if (isParallelTypeThread(contig_id->getParallelType())) { if (!gpu_lower->parallelDimensionMap().isExact( - kir_contig_id->parallelType())) { + contig_id->getParallelType())) { return false; } // If the domain has halo, the loop is expanded by the halo // extent, so we can't prove the loop extent is the same as the // parallel dimension. - if (!(halo_ext.has_value() && halo_ext.value() == 0)) { + if (halo_ext != 0) { return false; } } @@ -2912,50 +2863,70 @@ bool canOmitStopPredicate( // Returns predicates and the concrete (by loop map) root domains they cover std::pair, ReferenceTensor> Index:: getReferenceRootPredicates( - const kir::TensorView* kir_consumer_tv, + TensorView* consumer_tv, const std::vector& loops, kir::ForLoop* unswitch_or_vec_loop, bool shift_padding) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates"); const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); + + const bool is_unswitch = unswitch_or_vec_loop != nullptr; // Nothing needs to be done when padding is not required. 
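// Editor's sketch: a condensed standalone restatement of the stop-predicate
// omission logic above, with plain values in place of IR nodes. Parameter names
// are illustrative; the real check also requires the stop index and the extent
// to have no defining expression ("simple") before any of this applies.
#include <cstdint>
#include <optional>

bool canOmitStopPredicateSketch(
    bool index_and_extent_simple,
    std::optional<int64_t> stop_offset, // compile-time value, if known
    int64_t halo_extent,
    bool is_thread_parallel,
    bool parallel_dim_exact) {
  if (!index_and_extent_simple || !stop_offset.has_value()) {
    return false; // nothing can be proven without compile-time constants
  }
  if (halo_extent + stop_offset.value() > 0) {
    return false; // the stop position may run past the extent
  }
  if (is_thread_parallel && (!parallel_dim_exact || halo_extent != 0)) {
    return false; // extra threads/blocks would otherwise go unpredicated
  }
  return true;
}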
- if (shift_padding && !needsPadding(kir_consumer_tv->fuserTv())) { + if (shift_padding && !needsPadding(consumer_tv)) { return {{RootPredicateInfo::getFalseInfo()}, ReferenceTensor{}}; } - auto consumer_tv = kir_consumer_tv->fuserTv(); - // Get a reference tensor replayed as existing loop structure ReferenceTensor reference = IndexReferenceReplay::getReference(loops); // Generate halo information for reference. updateHaloInfoForReference(reference, consumer_tv); + const auto ref_2_consumer = indexMapReferenceTo( + consumer_tv, gpu_lower->caIndexMap(), reference.concrete_to_id); + + const auto reference_halo_extent_map = + getReferenceHaloExtentMap(reference, ref_2_consumer); + + auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv); + // Both start and stop positions may need to be predicated. Indexing // differs when generating predicates for unswitch. // NOTE: If we could find-and-replace KIR nodes, we could just // generate one index map, clone it and replace the loop-to-index // mappings of unswitched loops for the start predicate. - const auto ref_stop_index_map = getPredicateReferenceIndexing( - loops, reference, unswitch_or_vec_loop, false); - // If not unswitch, share the same indexing map as the stop index map - const auto& ref_start_index_map = unswitch_or_vec_loop != nullptr - ? getPredicateReferenceIndexing( - loops, reference, unswitch_or_vec_loop, true) - : ref_stop_index_map; - - auto ref_2_consumer = indexMapReferenceTo( - consumer_tv, gpu_lower->caIndexMap(), reference.concrete_to_id); + auto ref_stop_indexing = getPredicateReferenceIndexing( + loops, reference, unswitch_or_vec_loop, db_axis, false); + const auto consumer_stop_indexing = ref_stop_indexing.updateIndexCompute( + consumer_tv->domain(), + ref_2_consumer, + std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), + reference_halo_extent_map); + const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap(); + + // If not unswitch, share the same indexing map as the stop index + // map + std::unordered_map consumer_start_index_map; + if (is_unswitch) { + auto ref_start_indexing = getPredicateReferenceIndexing( + loops, reference, unswitch_or_vec_loop, db_axis, true); + const auto consumer_start_indexing = ref_start_indexing.updateIndexCompute( + consumer_tv->domain(), + ref_2_consumer, + std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), + reference_halo_extent_map); + consumer_start_index_map = consumer_start_indexing.indexMap(); + } else { + consumer_start_index_map = consumer_stop_index_map; + } // Get the contiguous ids we need to generate predicates for - auto contig_id_infos = - getPredicateContigIds(reference, consumer_tv, ref_2_consumer); + auto contig_id_infos = getPredicateContigIds(consumer_tv); auto non_divisible_splits = - getNonDivisibleReferenceDomainsToPredicate(consumer_tv, reference); + getNonDivisibleConsumerDomainsToPredicate(consumer_tv); contig_id_infos.insert( contig_id_infos.end(), non_divisible_splits.begin(), @@ -2972,52 +2943,22 @@ std::pair, ReferenceTensor> Index:: } auto root_ids = contig_id_entry.covered_ids; - auto kir_contig_id = - gpu_lower->lowerValue(contig_id)->as(); - const auto ref_stop_indexing_it = ref_stop_index_map.find(kir_contig_id); + const auto consumer_stop_indexing_it = + consumer_stop_index_map.find(contig_id); - // First condition below is due to broadcasts in consumers of consumer that - // are not in consumer there can be unresolved indexing in the reference - // tensor. 
This can happen when we have something like: TV3[i1o*i2, i1i] and - // TV1[i2] where tv3 and tv1 share their outer dimension. i1 will be part of - // reference tensors root domain, but when indexing into TV1 there aren't - // enough indices to resolve it. - // - // The condition also happens with Misaligned predicates, where + // First condition below happens with Misaligned predicates, where // inner-most vectorized loops are not included in the loops // parameter. Predicates involving vectorized loops are separately // generated in lower_misaligned_vectorization. // - // It can also happens with rfactored reductions. The reference - // tensor may include rfactored domains, so the contig id may be - // a root domain of the reference, not a rfactor root. Since - // there is no loop for rfactor domains, there's no indexing - // mapping for root domains. This seems safe as it can only happen - // with rfactor and rfactored tensors do not need predicates. - // // Second condition is simply to avoid predication on broadcasting axes as // it's not required. - if (ref_stop_indexing_it == ref_stop_index_map.end() || - ref_stop_indexing_it->second->isZeroInt()) { + if (consumer_stop_indexing_it == consumer_stop_index_map.end() || + consumer_stop_indexing_it->second->isZeroInt()) { continue; } - // Find a corresponding consumer root id if exists. Used to - // support shift. If a contig_id is a merged non-root domain, nothing - // is required to do for shift as shift-related domains are - // excluded from contig domains. - IterDomain* consumer_id = nullptr; - if (contig_id->definition() == nullptr || - contig_id_entry.is_non_divisible_split) { - auto it = ref_2_consumer.find(contig_id); - if (it != ref_2_consumer.end()) { - consumer_id = it->second; - } else { - continue; - } - } - RootPredicateInfo info; // Compute offsets for start and stop predicate. For non-shift, @@ -3032,53 +2973,50 @@ std::pair, ReferenceTensor> Index:: // The final predicates will look like: // (index + start_offset) >= 0 && (index + stop_offset) < extent. 
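// Editor's sketch: a tiny standalone restatement, on plain integers, of the
// final predicate form stated in the comment above. Either half is replaced by
// `true` when the corresponding check can be proven unnecessary
// (simplifyStartOffset returning nullptr, or canOmitStopPredicate succeeding).
// For unswitch the start and stop checks actually use different indices; a
// single index is used here only for brevity.
bool rootPredicateExample(
    long index,
    long start_offset,
    long stop_offset,
    long extent,
    bool omit_start,
    bool omit_stop) {
  const bool start_ok = omit_start || (index + start_offset >= 0);
  const bool stop_ok = omit_stop || (index + stop_offset < extent);
  return start_ok && stop_ok;
}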
- std::tie(info.start_offsets_, info.stop_offsets_) = getStartAndStopOffsets( - consumer_id, + std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets( + contig_id, consumer_tv, reference, - ref_start_index_map, - ref_stop_index_map, + consumer_start_index_map, + consumer_stop_index_map, shift_padding, unswitch_or_vec_loop != nullptr, contig_id_entry.is_non_divisible_split); - auto stop_index = ref_stop_indexing_it->second; - auto start_index = ref_start_index_map.at(kir_contig_id); + auto stop_index = consumer_stop_indexing_it->second; + auto start_index = consumer_start_index_map.at(contig_id); // Build predicates for start positions as: // start_index + start_offset >= 0 - for (auto start_offset : info.start_offsets_) { - if (canOmitStartPredicate(start_offset)) { - info.start_predicates_.push_back(ir_builder.trueVal()); - continue; - } + auto start_offset = simplifyStartOffset(info.start_offset_); + if (start_offset == nullptr) { + info.start_predicate_ = GpuLower::current()->kernel()->trueVal(); + } else { auto offsetted_start_index = - ir_builder.addExpr(start_index, start_offset); - auto pred = - ir_builder.geExpr(offsetted_start_index, ir_builder.zeroVal()) - ->as(); - info.start_predicates_.push_back(pred); + SimplifyingIrBuilder::addExpr(start_index, start_offset); + auto start_pred = + SimplifyingIrBuilder::geExpr( + offsetted_start_index, GpuLower::current()->kernel()->zeroVal()) + ->as(); + info.start_predicate_ = start_pred; } // Build predicates for stop positions as: // stop_index + stop_offset < IterDomain::extent - for (auto stop_offset : info.stop_offsets_) { - if (canOmitStopPredicate(stop_index, stop_offset, kir_contig_id)) { - info.stop_predicates_.push_back(ir_builder.trueVal()); - continue; - } - auto offsetted_stop_index = ir_builder.addExpr(stop_index, stop_offset); - auto pred = - ir_builder.ltExpr(offsetted_stop_index, kir_contig_id->extent()) - ->as(); - info.stop_predicates_.push_back(pred); + auto stop_offset = info.stop_offset_; + if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) { + info.stop_predicate_ = GpuLower::current()->kernel()->trueVal(); + } else { + auto offsetted_stop_index = + SimplifyingIrBuilder::addExpr(stop_index, stop_offset); + auto stop_pred = SimplifyingIrBuilder::ltExpr( + offsetted_stop_index, contig_id->extent()) + ->as(); + info.stop_predicate_ = stop_pred; } - // Transform ids from reference to concrete and consumer domains - // (based on loop compute at map) - for (auto ref_id : contig_id_entry.covered_ids) { - info.root_ids_.insert(reference.id_to_concrete.at(ref_id)); - info.consumer_ids_.insert(ref_2_consumer.at(ref_id)); + for (auto consumer_id : contig_id_entry.covered_ids) { + info.root_ids_.insert(consumer_id); } pred_info_vec.emplace_back(info); } @@ -3089,7 +3027,7 @@ std::pair, ReferenceTensor> Index:: bool Index::protectWithMagicZero( kir::ForLoop* loop, IterDomain* reference_domain, - kir::Val* ind) { + Val* ind) { bool ref_dom_simple = (reference_domain == nullptr ? true : reference_domain->definition() != nullptr); @@ -3100,16 +3038,9 @@ bool Index::protectWithMagicZero( } RootPredicateInfo RootPredicateInfo::getFalseInfo() { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - RootPredicateInfo info; - info.start_predicates_.push_back(ir_builder.falseVal()); - info.stop_predicates_.push_back(ir_builder.falseVal()); - // These are just placeholder. When the predicate is false, the - // offset should not be used. 
- info.start_offsets_.push_back(nullptr); - info.stop_offsets_.push_back(nullptr); + info.start_predicate_ = GpuLower::current()->kernel()->falseVal(); + info.stop_predicate_ = GpuLower::current()->kernel()->falseVal(); return info; } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index 83536067c19..27f1c911bde 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -69,30 +69,30 @@ class IndexCompute : public BackwardVisitor { void handle(Expr*) override; // return extent_map_[id] if exists, else return id->extent() - kir::Val* getExtent(kir::IterDomain* id); + Val* getExtent(IterDomain* id); //! True if a domain is not used to index - bool isZero(kir::IterDomain* id) const; + bool isZero(IterDomain* id) const; //! True if any dependent of a domain is not used to index - bool hasZeroMerged(kir::IterDomain* id) const; + bool hasZeroMerged(IterDomain* id) const; // Tensor domain we're mapping back to root const TensorDomain* td_; // NOLINT // Map we update as we propagate backward, containing all IDs in the // propagation. Initial indices are mapped with this map at tv->domain() - // and are back propagated to tv->rootDomain(). This index_map_ keeps the + // and are back propagated to tv->getRootDomain(). This index_map_ keeps the // indices at intermediate IterDomain's in that back propagation. - std::unordered_map index_map_; // NOLINT + std::unordered_map index_map_; // NOLINT // Map from IterDomain to their broadcasted extent. If a TV has I0*I1 but its // producer has B0*I1 this map will contain a mapping from the ID{B0*I1} to // the extent I0*I1. Also contains updated extents if we merge in a 0 index. // See zero_merged_in_. - std::unordered_map extent_map_; // NOLINT + std::unordered_map extent_map_; // NOLINT // Keeps track of domains that do not contribute to indexing - std::unordered_set zero_domains_; // NOLINT + std::unordered_set zero_domains_; // NOLINT // This set keeps track of IterDomain's that have had a zero index merged into // them. This happens if we do something like tv->axis(0)->split(4) then @@ -100,47 +100,46 @@ class IndexCompute : public BackwardVisitor { // indexing would be (0, i) then when we do the backward computation that zero // and i would attempt to be merged together. We handle indices like these // specially. 
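// Editor's sketch: a rough standalone illustration of the situation tracked by
// zero_merged_in_ as described above. After splitting an axis by 4 the
// flattened index is outer * 4 + inner; when the outer index is known to be
// zero (for example, a loop that does not contribute to a local/shared
// buffer's address), the back-propagated index collapses to just `inner`, and
// IndexCompute records the affected domain rather than emitting 0 * 4 + inner.
long mergedSplitIndexExample(long outer, long inner, long factor) {
  return outer * factor + inner; // with outer == 0 this is simply `inner`
}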
- std::unordered_set zero_merged_in_; + std::unordered_set zero_merged_in_; // IDs that are a result of contiguous merges - std::unordered_set contig_ids; + std::unordered_set contig_ids; // Mentions if we should propagate an index down a particular IterDomain path // if there's an option - std::unordered_set preferred_paths_; + std::unordered_set preferred_paths_; // Map from IterDomains to halo-extended extents in corresponding // reference tensor - std::unordered_map reference_halo_extent_map_; + std::unordered_map reference_halo_extent_map_; public: - const std::unordered_map& indexMap() const { + const std::unordered_map& indexMap() const { return index_map_; } - const std::unordered_map& extentMap() const { + const std::unordered_map& extentMap() const { return extent_map_; } - const std::unordered_set& zeroDomains() const { + const std::unordered_set& zeroDomains() const { return zero_domains_; } - const std::unordered_set& zeroMergedIn() const { + const std::unordered_set& zeroMergedIn() const { return zero_merged_in_; } // Propagate back from _td using initial_index_map IndexCompute( const TensorDomain* _td, - std::unordered_map initial_index_map, - std::unordered_map _extent_map, - std::unordered_set zero_domains, - std::unordered_set _zero_merged_in, + std::unordered_map initial_index_map, + std::unordered_map _extent_map, + std::unordered_set zero_domains, + std::unordered_set _zero_merged_in, const std::vector& _root_contiguity, - std::unordered_set preferred_paths = {}, - std::unordered_map - reference_halo_extent_map = {}); + std::unordered_set preferred_paths = {}, + std::unordered_map reference_halo_extent_map = {}); // Updates index_map, extent_map, and zero_merged_in based on id_map and // returns a new IndexCompute ready to be used. @@ -148,8 +147,8 @@ class IndexCompute : public BackwardVisitor { const TensorDomain* new_td, const std::unordered_map& id_map, const std::vector& _root_contiguity, - const std::unordered_map& - reference_halo_extent_map = {}); + const std::unordered_map& reference_halo_extent_map = + {}); virtual void run(); }; @@ -159,10 +158,10 @@ class IndexSwizzle : public IndexCompute { public: IndexSwizzle( const TensorView* tv, - std::unordered_map initial_index_map, - std::unordered_map extent_map, - std::unordered_set zero_domains, - std::unordered_set zero_merged_in); + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in); void run() override; @@ -183,51 +182,45 @@ class RootPredicateInfo { friend class Index; public: - const auto& startPredicates() const { - return start_predicates_; + const auto& startPredicate() const { + return start_predicate_; } - auto& startPredicates() { - return start_predicates_; + auto& startPredicate() { + return start_predicate_; } - const auto& startOffsets() const { - return start_offsets_; + const auto& startOffset() const { + return start_offset_; } - const auto& stopPredicates() const { - return stop_predicates_; + const auto& stopPredicate() const { + return stop_predicate_; } - const auto& stopOffsets() const { - return stop_offsets_; + const auto& stopOffset() const { + return stop_offset_; } const auto& rootIds() const { return root_ids_; } - const auto& consumerIds() const { - return consumer_ids_; - } - //! Return a false RootPredicateInfo, i.e., both start and stop //! predicates are false. 
static RootPredicateInfo getFalseInfo(); private: - // prdicates for lower end - std::vector start_predicates_; - // prdicates for upper end - std::vector stop_predicates_; - // Offsets of the start predicate - std::vector start_offsets_; - // Offsets of the stop predicate - std::vector stop_offsets_; + // prdicate for lower end + Bool* start_predicate_ = nullptr; + // prdicate for upper end + Bool* stop_predicate_ = nullptr; + // Offset of the start predicate + Val* start_offset_ = nullptr; + // Offset of the stop predicate + Val* stop_offset_ = nullptr; // Track which roots have been handled by the generated predicates std::unordered_set root_ids_; - // Consumer IDs that correspond to root_ids_ - std::unordered_set consumer_ids_; }; // Simple interface for IndexCompute @@ -236,24 +229,24 @@ class RootPredicateInfo { class Index { private: // Producer indexing if it's in shared or local memory - static std::vector getNonGlobalProducerStridedIndices( + static std::vector getNonGlobalProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops); // Consumer indexing if it's in shared or local memory - static std::vector getNonGlobalConsumerStridedIndices( + static std::vector getNonGlobalConsumerStridedIndices( const TensorView* consumer, const std::vector& loops); // Producer if it's in global memory - static std::vector getGlobalProducerStridedIndices( + static std::vector getGlobalProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops); // Consumer indexing if it's in global memory - static std::vector getGlobalConsumerStridedIndices( + static std::vector getGlobalConsumerStridedIndices( const TensorView* consumer, const std::vector& loops); @@ -276,7 +269,7 @@ class Index { //! root domain of a producer tensor. The size of the returned //! vector is guaranteed to be equal to the number of axes of the //! indexing root domain. - static std::vector getProducerStridedIndices( + static std::vector getProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops); @@ -285,7 +278,7 @@ class Index { //! root domain of a consumer tensor. The size of the returned //! vector is guaranteed to be equal to the number of axes of the //! indexing root domain. - static std::vector getConsumerStridedIndices( + static std::vector getConsumerStridedIndices( const TensorView* consumer, const std::vector& loops); @@ -313,7 +306,7 @@ class Index { //! vectorized loop. static std::pair, ReferenceTensor> getReferenceRootPredicates( - const kir::TensorView* kir_consumer_tv, + TensorView* consumer_tv, const std::vector& loops, kir::ForLoop* unswitch_or_vec_loop, bool padding_predicate); @@ -328,7 +321,7 @@ class Index { static bool protectWithMagicZero( kir::ForLoop* loop, IterDomain* reference_domain = nullptr, - kir::Val* ind = nullptr); + Val* ind = nullptr); }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp index fcd0a8937ed..27e5b93e94e 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp @@ -1,11 +1,10 @@ #include #include +#include #include #include #include -#include -#include namespace torch { namespace jit { @@ -41,16 +40,12 @@ IterDomain* IndexReferenceReplay::idCopy(IterDomain* id) { // reduction. 
All we care about are the transformations, and trying to make // sure we track correctly a replaying with consistent reduction/broadcast // domains is challenging and unnecessary. - auto copied_id = - new IterDomain(id->start(), id->extent(), id->getParallelType()); + auto copied_id = IrBuilder::create( + id->container(), id->start(), id->extent(), id->getParallelType()); replayed_ids_.emplace_back(copied_id); return copied_id; } -IterDomain* IndexReferenceReplay::toFusionID(kir::IterDomain* kir_id) { - return ca_map_.toFusion(kir_id); -} - IterDomain* IndexReferenceReplay::toConcrete(IterDomain* id) { return ca_map_.getConcreteMappedID(id); } @@ -70,7 +65,8 @@ void IndexReferenceReplay::handle(Split* split) { } // Replay the provided split operation and add it to the reference DAG - new Split( + IrBuilder::create( + split->container(), ref_outer, ref_inner, ref_in, @@ -101,7 +97,7 @@ void IndexReferenceReplay::handle(Merge* merge) { } // Replay the provided merge operation and add it to the reference DAG - new Merge(ref_out, ref_outer, ref_inner); + IrBuilder::create(merge->container(), ref_out, ref_outer, ref_inner); // Mark producers and consumers ref_id_consumed_.emplace(ref_outer); @@ -149,7 +145,7 @@ TensorDomain* IndexReferenceReplay::computeReplay() { loop_structure_.begin(), loop_structure_.end(), std::back_inserter(domain_ids), - [this](kir::ForLoop* fl) { return toFusionID(fl->iter_domain()); }); + [](kir::ForLoop* fl) { return fl->iter_domain(); }); // IterVisitor based traversals don't work because we don't have all outputs. // backward traversal's traverseFrom(domain_ids) will throw "Invalid backward @@ -194,7 +190,7 @@ TensorDomain* IndexReferenceReplay::computeReplay() { // Construct a tensor that's representitive of the replayed loop structure. std::vector loops_replayed_domain; for (auto loop : loop_structure_) { - auto loop_id = toFusionID(loop->iter_domain()); + auto loop_id = loop->iter_domain(); // Map to loops with the loop map, but make sure the replayed id is actually // a leaf in the replay. auto ref_id_it = std::find_if( @@ -222,7 +218,7 @@ TensorDomain* IndexReferenceReplay::computeReplay() { loops_replayed_domain.begin(), loops_replayed_domain.end(), [](IterDomain* id) { return id->definition() != nullptr; })) { - auto domain = new TensorDomain( + auto domain = IrBuilder::create( // If there was no replay only return a domain with a root domain. loops_replayed_domain); return domain; @@ -257,8 +253,9 @@ TensorDomain* IndexReferenceReplay::computeReplay() { } // Create and return the reference. - auto domain = new TensorDomain( - {root_domain_ids.begin(), root_domain_ids.end()}, + auto domain = IrBuilder::create( + std::vector( + root_domain_ids.begin(), root_domain_ids.end()), loops_replayed_domain); return domain; } @@ -266,26 +263,30 @@ TensorDomain* IndexReferenceReplay::computeReplay() { IndexCompute getReferenceIndexing( const std::vector& loop_structure, - TensorDomain* reference_tensor) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - + TensorDomain* reference_tensor, + kir::ForLoop* double_buffer_loop) { // Create a simple index mapping from loop iter domains to their local index. // This is only applicable to global memory buffers. 
- std::unordered_map initial_index_map; + std::unordered_map initial_index_map; TORCH_INTERNAL_ASSERT(loop_structure.size() <= reference_tensor->nDims()); int magic_zero_loop = -1; for (const auto loop_i : c10::irange(loop_structure.size())) { auto ref_axis = reference_tensor->axis(loop_i); - auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as(); auto loop = loop_structure[loop_i]; auto ind = loop->index(); - ; - initial_index_map[kir_ref_axis] = ind; + initial_index_map[ref_axis] = ind; if (loop->vectorize()) { - initial_index_map[kir_ref_axis] = ir_builder.create(0); + initial_index_map[ref_axis] = GpuLower::current()->kernel()->zeroVal(); + } else if (double_buffer_loop == loop) { + // This version of getReferenceIndexing is only used for + // indexing global tensors. When indexing global producers, the + // index for a double buffered loop needs to be incremented. The + // parameter double_buffer_loop should be nullptr when indexing + // global consumers tensors. + initial_index_map[ref_axis] = + IrBuilder::addExpr(ind, GpuLower::current()->kernel()->oneVal()); } if (Index::protectWithMagicZero(loop, ref_axis, ind)) { @@ -295,10 +296,9 @@ IndexCompute getReferenceIndexing( // Add magic zero to a fairly inner most index if (magic_zero_loop >= 0) { - auto ref_id = gpu_lower->lowerValue(reference_tensor->axis(magic_zero_loop)) - ->as(); - initial_index_map[ref_id] = ir_builder.addExpr( - initial_index_map[ref_id], ir_builder.magicZeroVal()); + auto ref_id = reference_tensor->axis(magic_zero_loop); + initial_index_map[ref_id] = IrBuilder::addExpr( + initial_index_map[ref_id], FusionGuard::getCurFusion()->magicZeroVal()); } // Send to the other version of reference indexing that directly takes the @@ -310,19 +310,17 @@ IndexCompute getReferenceIndexing( IndexCompute getReferenceIndexing( const std::vector& loop_structure, TensorDomain* reference_tensor, - std::unordered_map index_map, - std::unordered_set zero_domains, + std::unordered_map index_map, + std::unordered_set zero_domains, std::unordered_set preferred_paths, - std::unordered_map halo_extent_map) { - auto gpu_lower = GpuLower::current(); - + std::unordered_map halo_extent_map) { // I thought this might be necesasry, but turns out it's not. I think it's // because of the root ordering above, however leaving it in case we find // out it is necessary in some cases. At the time of commiting, cuda-memcheck // passed without this. 
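// Editor's sketch: a standalone rendering of the initial index map assembled
// above for reference indexing of global tensors, with axis positions and
// plain integers standing in for IterDomains and Vals; the magic-zero
// protection is omitted. A vectorized loop contributes 0, the double-buffered
// loop (global producers only) contributes its index advanced by one, and
// every other loop contributes its own index unchanged.
#include <cstddef>
#include <vector>

std::vector<long> initialIndexSketch(
    const std::vector<long>& loop_indices,
    long vectorized_pos, // -1 if no vectorized loop
    long double_buffer_pos) { // -1 if no double-buffered loop
  std::vector<long> initial(loop_indices.size());
  for (std::size_t i = 0; i < loop_indices.size(); ++i) {
    if (static_cast<long>(i) == vectorized_pos) {
      initial[i] = 0; // vectorized loops do not contribute an index
    } else if (static_cast<long>(i) == double_buffer_pos) {
      initial[i] = loop_indices[i] + 1; // read one stage ahead
    } else {
      initial[i] = loop_indices[i];
    }
  }
  return initial;
}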
// - // std::unordered_map reference_extent_map; for (auto loop : loop_structure) { + // std::unordered_map reference_extent_map; for (auto loop : loop_structure) { // // If there's a broadcast merged in the for loop ID we want to track its // // extent // auto inputs = InputsOf::outputs( @@ -342,16 +340,6 @@ IndexCompute getReferenceIndexing( // } // } - // Convert to preferred_path to kir::IterDomain for IndexCompute - std::unordered_set kir_preferred_path; - std::transform( - preferred_paths.begin(), - preferred_paths.end(), - std::inserter(kir_preferred_path, kir_preferred_path.begin()), - [&gpu_lower](IterDomain* id) { - return gpu_lower->lowerValue(id)->as(); - }); - IndexCompute compute( reference_tensor, index_map, // NOLINT @@ -359,9 +347,9 @@ IndexCompute getReferenceIndexing( // in this function {}, zero_domains, - std::unordered_set(), + std::unordered_set(), reference_tensor->contiguity(), - kir_preferred_path, + preferred_paths, halo_extent_map); compute.run(); diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.h b/torch/csrc/jit/codegen/cuda/index_reference_replay.h index c4626213e76..fcb8e1f94e8 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.h +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -34,10 +34,6 @@ class IndexReferenceReplay : public OptInDispatch { // Make a new id for the reference replay based on the provided id IterDomain* idCopy(IterDomain* id); - // Use the compute at map to get the fusion IterDomain from the - // kir::IterDomain - IterDomain* toFusionID(kir::IterDomain* kir_id); - // Return the concrete entry of the non-reference id IterDomain* toConcrete(IterDomain* id); @@ -87,16 +83,17 @@ class IndexReferenceReplay : public OptInDispatch { IndexCompute getReferenceIndexing( const std::vector& loop_structure, TensorDomain* reference_domain, - std::unordered_map index_map, - std::unordered_set zero_domains, + std::unordered_map index_map, + std::unordered_set zero_domains, std::unordered_set preferred_path, - std::unordered_map halo_extent_map = {}); + std::unordered_map halo_extent_map = {}); // Short cut for global TVs. Index into the reference based on all loop indicies // in the loop structure. IndexCompute getReferenceIndexing( const std::vector& loop_structure, - TensorDomain* reference_domain); + TensorDomain* reference_domain, + kir::ForLoop* double_buffer_loop = nullptr); // When indexing there are sometimes an option to propagate an index down // multiple paths. 
This will return the IterDomains in the history of the diff --git a/torch/csrc/jit/codegen/cuda/instrumentation.cpp b/torch/csrc/jit/codegen/cuda/instrumentation.cpp index 52e16b3a7af..d227df0ab26 100644 --- a/torch/csrc/jit/codegen/cuda/instrumentation.cpp +++ b/torch/csrc/jit/codegen/cuda/instrumentation.cpp @@ -1,6 +1,6 @@ #include -#include +#include #ifdef _WIN32 #include diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index bd54d30811d..d21004ae154 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -15,13 +15,15 @@ C10_DEFINE_bool( C10_DEFINE_bool( torch_jit_nvfuser_horizontal_fusion, true, - "enable single node fusion for nvfuser"); + "enable horizontal fusion for nvfuser"); namespace torch { namespace jit { namespace fuser { namespace cuda { +static std::atomic cuda_fusion_guard_mode{true}; + bool getSingletonFusion() { return FLAGS_torch_jit_nvfuser_singleton_fusion; } @@ -42,8 +44,6 @@ bool setHorizontalFusion(bool value) { return old_value; } -static std::atomic cuda_fusion_guard_mode{true}; - std::atomic& getCudaFusionGuardMode() { return cuda_fusion_guard_mode; } @@ -329,6 +329,220 @@ RegisterOperators reg_guard({ aliasAnalysisFromSchema()), }); +// Infer dynamic axis (-1) in view_sizes given tensor_sizes +bool inferViewShape( + c10::List tensor_sizes, + c10::List view_sizes) { + int64_t dynamic_index = -1; + size_t view_size_num_elements = 1; + for (size_t idx = 0; idx < view_sizes.size(); ++idx) { + if (view_sizes[idx] == -1) { + TORCH_INTERNAL_ASSERT( + dynamic_index == -1, "Only one dimension can by inferred.") + dynamic_index = idx; + } else { + TORCH_INTERNAL_ASSERT(view_sizes[idx] > 0); + view_size_num_elements *= view_sizes[idx]; + } + } + const size_t kNumElements = std::accumulate( + tensor_sizes.begin(), tensor_sizes.end(), 1, std::multiplies<>()); + + if (kNumElements % view_size_num_elements != 0) { + return false; + } + + if (dynamic_index != -1) { + view_sizes[dynamic_index] = kNumElements / view_size_num_elements; + } + + return true; +} + +//! [ Note -- type guard logic in CudaFusionViewGuard ] +//! +//! CudaFusionViewGuard is used to guard input tensors to a `CudaFusionGroup` +//! that contains view operations, so that we would not feed inputs that +//! violate the graph defined in `GraphCache`. +//! +//! output = view(self, view-sizes) +//! +//! View Guard Inputs: +//! 1. self tensor_sizes - dynamic size List[Int] +//! 2. view_sizes - profile_ivalue List[Int] +//! 3. tensor_constraint - Constant List[Int] +//! 4. view_sizes_constraint - Constant List[Int] +//! +//! Things that we check: +//! 1. The #dimensions are the same for self tensor and its constraint +//! 2. The #dimensions are the same for view-sizes and its constraint +//! 3. Self tensor does not violate its constraint +//! a. Queue unrestricted sizes +//! b. Calculate #elements in self tensor +//! 4. view-sizes does not violate its constraint +//! a. Pop unrestricted sizes from queue +//! b. Calculate #elements in view-sizes +//! 5. The #elements is the same for self tensor and view-sizes +//! +//! Constraints: +//! A restricted axis creates a graph constraint, so its sizes is static. +//! An unrestricted axis is allowed to have a dynamic size, if it is consistent +//! between self tensor and view-sizes. It is marked with -1 in the constraint. +//! Only iterDomains with the Keep transform are dynamic. All other transforms +//! create a static constraint. +//! 
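// Editor's worked example of the guard semantics described in the note above,
// using the constraints that appear in the example graph further below. The
// runtime sizes are invented for illustration, not taken from this PR:
//   tensor_constraint     = [-1, -1, -1, 6],    tensor_sizes = [2, 3, 4, 6]
//   view_sizes_constraint = [-1, -1, -1, 2, 3], view_sizes   = [2, 3, 4, 2, 3]
// Dimension counts match their constraints, the static axes agree (6; 2, 3),
// the dynamic tensor sizes {2, 3, 4} are queued and consumed in order by the
// three dynamic view axes, and the element counts match, so the guard passes.
// An input with tensor_sizes = [2, 3, 4, 5] would fail the static check on the
// last axis and execution would take the fallback block instead.
#include <cassert>

void viewGuardWorkedExample() {
  // The final invariant checked by the guard: equal element counts.
  const long tensor_elems = 2 * 3 * 4 * 6; // 144
  const long view_elems = 2 * 3 * 4 * 2 * 3; // 144
  assert(tensor_elems == view_elems);
}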
+bool checkViewGuard( + c10::List tensor_sizes, + c10::List view_sizes, + c10::List tensor_constraint, + c10::List view_sizes_constraint) { + // 1: Num Dimensions Check + if (tensor_constraint.size() != tensor_sizes.size() || + view_sizes_constraint.size() != view_sizes.size()) { + return false; + } + + // If axis allows dynamic sizes, then add tensor size to this queue. + // For dynamic axes in view_sizes, check that it is consistent with + // the corresponding tensor size. + std::queue dynamic_axis_queue; + + // 2. Tensor Static Check + int64_t tensor_size_product = 1; + for (const auto idx : c10::irange(tensor_sizes.size())) { + if (tensor_constraint[idx] == -1) { + dynamic_axis_queue.push(tensor_sizes[idx]); + } else if (tensor_constraint[idx] != tensor_sizes[idx]) { + return false; + } + tensor_size_product *= tensor_sizes[idx]; + } + + // 3. View-Sizes Static Check + int64_t view_size_product = 1; + for (const auto idx : c10::irange(view_sizes.size())) { + auto dynamic_size = (view_sizes_constraint[idx] == -1) + ? dynamic_axis_queue.front() + : view_sizes_constraint[idx]; + if (dynamic_size != view_sizes[idx]) { + return false; + } + view_size_product *= dynamic_size; + if (view_sizes_constraint[idx] == -1) { + dynamic_axis_queue.pop(); + } + } + + // 4. Check view invariant + // The number of elements in the input and output tensors are the same. + return tensor_size_product == view_size_product; +} + +//! +//! CudaFusionViewGuard Example Graph: +//! +//! graph(%self : __torch__.BiasViewRelu, +//! %inputs.1 : Tensor): +//! %2 : int = prim::Constant[value=-1]() # dynamic_bvg.py:50:40 +//! %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25 +//! %4 : NoneType = prim::Constant() +//! %5 : int[] = prim::Constant[value=[2, 3]]() +//! %6 : int[] = aten::size(%inputs.1) # dynamic_bvg.py:50:25 +//! %7 : int[] = aten::slice(%6, %4, %2, %3) # dynamic_bvg.py:50:25 +//! %view_shape.1 : int[] = aten::add(%7, %5) # dynamic_bvg.py:50:25 +//! %bias : Tensor = prim::GetAttr[name="bias"](%self) +//! %10 : int[] = aten::size(%bias) +//! %11 : int[] = prim::BroadcastSizes(%6, %10) +//! %12 : bool = prim::CudaFusionGuard[types=[...]](%inputs.1, %bias) +//! %13 : int[] = prim::Constant[value=[-1, -1, -1, 6]]() +//! %14 : int[] = prim::Constant[value=[-1, -1, -1, 2, 3]]() +//! %15 : bool = prim::CudaFusionViewGuard(%11, %view_shape.1, %13, %14) +//! %16 : bool[] = prim::ListConstruct(%15, %12) +//! %17 : bool = aten::all(%16) +//! %18 : Tensor = prim::If(%17) +//! block0(): +//! %19 : Tensor = prim::CudaFusionGroup_0[cache_id=0](%inputs.1, %bias) +//! -> (%19) +//! block1(): +//! %20 : Function = prim::Constant[name="fallback_fn", fallback=1]() +//! %21 : (...) = prim::CallFunction(%20, %inputs.1, %bias, %view_shape.1) +//! %22 : Float(...) = prim::TupleUnpack(%21) +//! -> (%22) +//! return (%18) +//! with prim::CudaFusionGroup_0 = graph(%0 : Float(...), +//! %1 : Float(...)): +//! %2 : int[] = prim::Constant[value=[2, 3, 4, 2, 3]]() +//! %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25 +//! %o.1 : Float(...) = aten::add(%0, %1, %3) # dynamic_bvg.py:51:16 +//! %5 : Float(...) = prim::view_copy(%o.1, %2) +//! %6 : Float(...) = aten::relu(%5) # dynamic_bvg.py:53:19 +//! return (%6) +//! +RegisterOperators view_guard({ + Operator( + "prim::CudaFusionViewGuard(...) -> bool", + // prim::CudaFusionViewGuard returns a fresh Boolean type without + // aliasing. if we would ever return refined tensor, which would change + // aliasing analysis, we should update aliasdb pass. 
+ [](const Node* node) -> Operation { + return [](Stack& stack) { + // view_sizes_constraint - Constant List[Int] + at::ArrayRef inputs = last(stack, 4); + + // tensor_sizes is the runtime size for the self tensor + // tensor_sizes - dynamic size List[Int] + TORCH_INTERNAL_ASSERT( + inputs[0].isIntList(), "tensor_sizes needs to be Int List"); + auto tensor_sizes = inputs[0].toIntList(); + + // profiled_view_sizes is the runtime view size + // profiled_view_sizes - profile_ivalue List[Int] + TORCH_INTERNAL_ASSERT( + inputs[1].isIntList(), + "profiled_view_sizes needs to be Int list"); + auto profiled_view_sizes = inputs[1].toIntList(); + + // tensor_constraint is a constant List[Int] + // used to guard tensor_sizes + TORCH_INTERNAL_ASSERT( + inputs[2].isIntList(), + "tensor constraint needs to be Int List"); + auto tensor_constraint = inputs[2].toIntList(); + + // view_sizes_constraint is a constant List[Int] + // used to guard profiled_view_sizes + TORCH_INTERNAL_ASSERT( + inputs[3].isIntList(), + "view_sizes constraint needs to be Int List"); + auto view_sizes_constraint = inputs[3].toIntList(); + + // Drop after gather all input arguments + // If an argument is moved, it is destroyed when dropped from stack + drop(stack, 4); + + auto status = inferViewShape(tensor_sizes, profiled_view_sizes); + if (!status) { + push(stack, IValue(false)); + return; + } + + if (!fuser::cuda::getCudaFusionGuardMode()) { + push(stack, IValue(true)); + return; + } + + auto guard_status = checkViewGuard( + tensor_sizes, + profiled_view_sizes, + tensor_constraint, + view_sizes_constraint); + push(stack, IValue(guard_status)); + return; + }; + }, + aliasAnalysisFromSchema()), +}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) RegisterOperators reg_add_optional({ Operator( @@ -346,6 +560,160 @@ RegisterOperators reg_add_optional({ }, aliasAnalysisFromSchema()), }); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_view_copy({ + Operator( + "prim::view_copy(Tensor self, int[] size) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "view_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, size; + pop(stack, self, size); + push(stack, at::native::view(self.toTensor(), size.toIntVector())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_reshape_copy({ + Operator( + "prim::reshape_copy(Tensor self, int[] shape) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "reshape_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, shape; + pop(stack, self, shape); + push( + stack, + at::native::reshape(self.toTensor(), shape.toIntVector())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_squeeze_copy({ + Operator( + "prim::squeeze_copy(Tensor self) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "squeeze_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self; + pop(stack, self); + 
push(stack, at::squeeze(self.toTensor())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_squeeze_dim_copy({ + Operator( + "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "squeeze_dim_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, dim; + pop(stack, self, dim); + push(stack, at::squeeze(self.toTensor(), dim.toInt())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_unsqueeze_copy({ + Operator( + "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "unsqueeze_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, dim; + pop(stack, self, dim); + push(stack, at::unsqueeze(self.toTensor(), dim.toInt())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_infer_unsqueeze_size({ + Operator( + "prim::infer_unsqueeze_size(int[] a, int dim) -> int[]", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto dim = pop(stack).toInt(); + auto size = pop(stack).toIntVector(); + if (dim < 0) { + dim = dim + 1 + size.size(); + } + auto it = size.begin() + dim; + size.insert(it, 1); + push(stack, IValue(size)); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_infer_squeeze_dim_size({ + Operator( + "prim::infer_squeeze_size(int[] a, int dim) -> int[]", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto dim = pop(stack).toInt(); + auto size = pop(stack).toIntVector(); + if (dim < 0) { + dim = dim + size.size(); + } + auto it = size.begin() + dim; + if (*it == 1) { + size.erase(it); + } + push(stack, IValue(size)); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_infer_squeeze_size({ + Operator( + "prim::infer_squeeze_size.dim(int[] a) -> int[]", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto size = pop(stack).toIntVector(); + + for (auto it = size.begin(); it != size.end(); it++) { + if (*it == 1) { + auto pre = it - 1; + size.erase(it); + it = pre; + } + } + push(stack, IValue(size)); + }; + }, + aliasAnalysisFromSchema()), +}); + } // namespace } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/interface.h b/torch/csrc/jit/codegen/cuda/interface.h index 1ab9e6d8008..8afa854ea5c 100644 --- a/torch/csrc/jit/codegen/cuda/interface.h +++ b/torch/csrc/jit/codegen/cuda/interface.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index cf3d9c7a8c7..6a094c104df 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -1,8 +1,12 @@ #include #include #include +#include #include #include +#include +#include +#include #include #include @@ -20,16 +24,20 @@ namespace jit { namespace fuser { 
namespace cuda { +Statement::Statement(IrBuilderPasskey passkey) { + ir_container_ = passkey.ir_container_; +} + Statement::Statement(const Statement* src, IrCloner* ir_cloner) { - // IRCloner when cloning to a new fusion will copy the names of the original - // fusion. If we're cloning into the same fusion, we let Val and Expr get - // their names as usual by registering with the current fusion in their - // constructors, so don't overwrite that here. - if (src->fusion() != ir_cloner->fusion()) { - name_ = src->name_; - } - fusion_ = ir_cloner->fusion(); - ir_cloner->registerClone(src, this); + ir_container_ = ir_cloner->container(); +} + +void Statement::setName(IrContainerPasskey, StmtNameType name) { + name_ = name; +} + +void Statement::setName(IrBuilderPasskey, StmtNameType name) { + name_ = name; } Val* Statement::asVal() { @@ -42,24 +50,37 @@ Expr* Statement::asExpr() { return this->as(); } -void Statement::print() const { - IrPrinter ir_printer(std::cout); +std::string Statement::toString() const { + std::stringstream ss; + IrPrinter ir_printer(ss); ir_printer.handle(this); - std::cout << std::endl; + return ss.str(); } -// When we create a Val we immediately register them with the active fusion. -Val::Val(ValType _vtype, DataType _dtype, bool register_val) - : vtype_(_vtype), dtype_(_dtype) { - Fusion* fusion = FusionGuard::getCurFusion(); - TORCH_CHECK( - fusion != nullptr, "No active fusion group found when creating a Val."); - fusion_ = fusion; - if (register_val) { - name_ = fusion_->registerVal(this); - } +std::string Statement::toInlineString() const { + std::stringstream ss; + IrPrinter ir_printer(ss); + ir_printer.print_inline(this); + return ss.str(); } +Fusion* Statement::fusion() const { + TORCH_INTERNAL_ASSERT( + ir_container_->isA(), "Statement does not belong to a fusion."); + return ir_container_->as(); +} + +kir::Kernel* Statement::kernel() const { + TORCH_INTERNAL_ASSERT( + ir_container_->isA(), + "Statement does not belong to a kernel."); + return ir_container_->as(); +} + +// When we create a Val we immediately register them with the active fusion. +Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype) + : Statement(passkey), vtype_(_vtype), dtype_(_dtype) {} + // NOTE: we don't clone the definition_ and uses_ here // since they may introduce cloning cycles. Instead, we copy // the original pointers and we'll fix them up later part of the @@ -71,12 +92,7 @@ Val::Val(const Val* src, IrCloner* ir_cloner) vtype_(src->vtype_), dtype_(src->dtype_), is_fusion_input_(src->is_fusion_input_), - is_fusion_output_(src->is_fusion_output_) { - // If we're "cloning" into the same fusion, register with the fusion - if (src->fusion() == ir_cloner->fusion()) { - name_ = src->fusion()->registerVal(this); - } -} + is_fusion_output_(src->is_fusion_output_) {} const std::vector& Val::uses() const { if (vtype_ == ValType::TensorView) { @@ -92,33 +108,33 @@ namespace { // Traverse definition of all values involved in constructing the provided val. // Check if all values involved are constant values, meaning the provided // val is also a constant value. 
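// Editor's hedged sketch of what the ConstCheck visitor below establishes for
// Val::isConstScalar(): a scalar is constant only if every leaf of its
// definition chain is a compile-time constant, and any NamedScalar (such as a
// launch dimension) makes the whole expression non-constant. The include
// paths, the template argument to IrBuilder::create (elided by the diff
// formatting), and the helper calls are reconstructed assumptions for
// illustration, not text from this PR.
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>

using namespace torch::jit::fuser::cuda;

void constCheckSketch() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto* one = IrBuilder::create<Int>(1); // literal constant
  auto* two = IrBuilder::create<Int>(2);
  auto* sum = add(one, two); // every leaf is constant
  TORCH_INTERNAL_ASSERT(sum->isConstScalar());

  auto* bdimx = NamedScalar::getParallelDim(ParallelType::TIDx);
  auto* mixed = add(one, bdimx); // depends on a NamedScalar
  TORCH_INTERNAL_ASSERT(!mixed->isConstScalar());
}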
-class ConstCheck : OptOutConstDispatch { +class ConstCheck : private OptOutConstDispatch { private: bool is_const_ = true; - void handle(const Bool* b) override { + void handle(const Bool* b) final { is_const_ = is_const_ && b->isConst(); } - void handle(const Double* d) override { + void handle(const Double* d) final { is_const_ = is_const_ && d->isConst(); } - void handle(const Int* i) override { + void handle(const Int* i) final { is_const_ = is_const_ && i->isConst(); } - void handle(const NamedScalar* ns) override { + void handle(const NamedScalar* ns) final { is_const_ = is_const_ && false; } - void handle(const Expr* expr) override { + void handle(const Expr* expr) final { for (auto inp : expr->inputs()) { handle(inp); } } - void handle(const Val* val) override { + void handle(const Val* val) final { if (val->definition() != nullptr) { handle(val->definition()); } else { @@ -137,15 +153,18 @@ class ConstCheck : OptOutConstDispatch { } // namespace bool Val::isConstScalar() const { - if (!isScalar()) + if (!isScalar()) { return false; + } return ConstCheck::isConst(this); } c10::optional Val::getInt() const { if (isConstScalar() && isAnInt()) { if (this->getValType() == ValType::Scalar) { - return this->as()->value(); + if (this->isA()) { + return this->as()->value(); + } } } return c10::optional(); @@ -169,7 +188,7 @@ c10::optional Val::getDataType() const { bool Val::isProducerOf(const Val* other) const { TORCH_INTERNAL_ASSERT(other != nullptr); - TORCH_INTERNAL_ASSERT(fusion() == other->fusion()); + TORCH_INTERNAL_ASSERT(container() == other->container()); if (definition() == nullptr) { return false; @@ -186,23 +205,14 @@ bool Val::isConsumerOf(const Val* other) const { // We don't register with the active fusion in Expr as this needs to be done // after inputs and outputs are registered with the Expr -Expr::Expr(ExprType type) : type_{type} { - Fusion* fusion = FusionGuard::getCurFusion(); - if (fusion == nullptr) - TORCH_CHECK(false, "No active fusion group found when creating an Expr."); - fusion_ = fusion; -} +Expr::Expr(IrBuilderPasskey passkey, ExprType etype) + : Statement(passkey), etype_{etype} {} Expr::Expr(const Expr* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), - type_(src->type_), + etype_(src->etype_), inputs_(ir_cloner->clone(src->inputs_)), - outputs_(ir_cloner->clone(src->outputs_)) { - // If we're "cloning" into the same fusion, register with the fusion - if (src->fusion() == ir_cloner->fusion()) { - name_ = src->fusion()->registerExpr(this); - } -} + outputs_(ir_cloner->clone(src->outputs_)) {} bool Expr::sameAs(const Statement* other) const { if (this == other) { @@ -227,6 +237,30 @@ bool Expr::sameAs(const Statement* other) const { return true; } +kir::Predicate* Expr::predicate() const { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + return predicate_; +} + +void Expr::setPredicate(kir::Predicate* predicate) { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + predicate_ = predicate; +} + +kir::Predicate* Expr::writePredicate() const { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + return write_predicate_; +} + +void Expr::setWritePredicate(kir::Predicate* write_predicate) { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + write_predicate_ = write_predicate; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h 
b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 2e0fa0885bd..1b8444fae46 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #include #include @@ -35,6 +35,8 @@ namespace jit { namespace fuser { namespace cuda { +using ValueId = int32_t; + using StmtNameType = unsigned int; constexpr StmtNameType kInvalidStmName = @@ -48,6 +50,22 @@ class UnaryOp; class BinaryOp; class IterDomain; class IrCloner; +class IrContainer; +class IrBuilderPasskey; +class IrContainerPasskey; + +namespace kir { +class Kernel; +class Predicate; +} // namespace kir + +// Passkey for container to register names with statements +class ExprPasskey { + friend class Expr; + + private: + explicit ExprPasskey() {} +}; TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; @@ -60,12 +78,12 @@ TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; //! is also important for the design to have a dispatch system for a Statment. //! Basically beinng able to succienctly traverse down the inhereitance stack of //! a Statment at runtime. This is currently implemented in dispatch.h -//! class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { friend void swap(Fusion&, Fusion&) noexcept; + friend void swap(IrContainer& a, IrContainer& b) noexcept; public: - Statement() = default; + Statement() = delete; // Cloning constructor Statement(const Statement* src, IrCloner* ir_cloner); @@ -78,7 +96,7 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { static void constDispatch(T handler, const Statement* const); template - static Statement* mutatorDispatch(T mutator, Statement*); + static void mutatorDispatch(T mutator, Statement*); // Accessor functions to types. Vals always have a DataType, Exprs never do virtual c10::optional getValType() const { @@ -106,8 +124,14 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { Expr* asExpr(); // Return the fusion this statement belongs to - Fusion* fusion() const { - return fusion_; + Fusion* fusion() const; + + // Return the kernel this statement belongs to + kir::Kernel* kernel() const; + + // Return the container this statement belongs to + IrContainer* container() const { + return ir_container_; } // Return the int that represents its name @@ -115,6 +139,13 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { return name_; } + // Set the statements' name. Typically the container will set the name, + // however if we're dealing with cloning, IrBuilder will set the name, this + // maybe should be from IrCloner, however I didn't want to add another + // passkey. + void setName(IrContainerPasskey, StmtNameType name); + void setName(IrBuilderPasskey, StmtNameType name); + virtual bool sameType(const Statement* const other) { if (isVal() && other->isVal()) return getValType().value() == other->getValType().value(); @@ -129,13 +160,17 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { return this == other; } - void print() const; + std::string toString() const; + std::string toInlineString() const; protected: + Statement(IrBuilderPasskey); + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) StmtNameType name_ = kInvalidStmName; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Fusion* fusion_ = nullptr; + IrContainer* ir_container_ = nullptr; }; //! 
A Val represents a "value." These are objects, like tensors, scalars, and @@ -169,34 +204,43 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { //! class TORCH_CUDA_CU_API Val : public Statement { public: - // We may not want to register this value during Val's constructor. The reason - // for this is that if we register the val, then in a derived constructor try - // to throw, fusion's destructor will get called, but the pointer to this Val - // will be invalid. When fusion tries to delete this value it will cause a seg - // fault, instead of showing the thrown error. explicit Val( + IrBuilderPasskey, ValType _vtype, - DataType _dtype = DataType::Null, - bool register_val = true); + DataType _dtype = DataType::Null); Val(const Val* src, IrCloner* ir_cloner); - // TODO: why is this optional? - // + // Dispatch functions, definitions in dispatch.cpp + template + static void dispatch(T handler, Val*); + + template + static void constDispatch(T handler, const Val* const); + + template + static void mutatorDispatch(T mutator, Val*); + c10::optional getValType() const override { return vtype_; } + ValType vtype() const { + return vtype_; + } + + DataType dtype() const { + return dtype_; + } + // Throws if no DataType is found. Vals must have a DataType - // - // TODO: why is this optional? - // c10::optional getDataType() const override; bool isScalar() const { return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar; } + // Returns if all dependencies are constant scalars bool isConstScalar() const; bool isAnInt() const { @@ -205,6 +249,11 @@ class TORCH_CUDA_CU_API Val : public Statement { c10::optional getInt() const; + // Returns if no dependencies and is a constant scalar. + virtual bool isConst() const { + return false; + } + bool isZeroInt() const; bool isOneInt() const; @@ -254,15 +303,11 @@ class TORCH_CUDA_CU_API Val : public Statement { return evaluator_index_; } - // Dispatch functions, definitions in dispatch.cpp - template - static void dispatch(T handler, Val*); - - template - static void constDispatch(T handler, const Val* const); - - template - static Statement* mutatorDispatch(T mutator, Val*); + // Following is managed by Fusion (or kirIrBuilder) and can change. + // TODO: Protect with a passkey. + void setDefinition(Expr* expr) { + definition_ = expr; + } protected: friend Fusion; @@ -272,19 +317,17 @@ class TORCH_CUDA_CU_API Val : public Statement { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const DataType dtype_; - // Following is managed by Fusion and can change. - void setDefinition(Expr* expr) { - definition_ = expr; - } - + // TODO: Add fusion passkey for this void setIsFusionInput(bool is_fusion_input) { is_fusion_input_ = is_fusion_input; } + // TODO: Add fusion passkey for this void setIsFusionOutput(bool is_fusion_output) { is_fusion_output_ = is_fusion_output; } + // TODO: Add fusion or container passkey for this void setUses(const std::vector& uses) { uses_ = uses; } @@ -297,6 +340,7 @@ class TORCH_CUDA_CU_API Val : public Statement { Expr* definition_ = nullptr; std::vector uses_; + // Expr evaluator idx; int evaluator_index_ = -1; }; @@ -342,15 +386,16 @@ class TORCH_CUDA_CU_API Val : public Statement { //! 
class TORCH_CUDA_CU_API Expr : public Statement { public: - explicit Expr(ExprType type); + explicit Expr(IrBuilderPasskey, ExprType type); + Expr(const Expr* src, IrCloner* ir_cloner); c10::optional getExprType() const override { - return type_; + return etype_; } - ExprType type() const { - return type_; + ExprType etype() const { + return etype_; } bool sameAs(const Statement* other) const override; @@ -380,23 +425,46 @@ class TORCH_CUDA_CU_API Expr : public Statement { static void constDispatch(T handler, const Expr* const); template - static Statement* mutatorDispatch(T mutator, Expr*); + static void mutatorDispatch(T mutator, Expr*); + + // TODO: Protect based on being in kernel container + kir::Predicate* predicate() const; + + // TODO: Protect based on being in kernel container + void setPredicate(kir::Predicate* predicate); + + // TODO: Protect based on being in kernel container + kir::Predicate* writePredicate() const; + + // TODO: Protect based on being in kernel container + void setWritePredicate(kir::Predicate* write_predicate); protected: + // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); inputs_.push_back(input); } + // TODO: Add Fusion passkey void addOutput(Val* output) { TORCH_INTERNAL_ASSERT(output != nullptr); outputs_.push_back(output); } + ExprPasskey exprPasskey() { + return ExprPasskey(); + } + private: - ExprType type_ = ExprType::Invalid; + ExprType etype_ = ExprType::Invalid; std::vector inputs_; std::vector outputs_; + + kir::Predicate* predicate_ = nullptr; + + // Only used for reduction-related expressions + kir::Predicate* write_predicate_ = nullptr; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp similarity index 50% rename from torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp rename to torch/csrc/jit/codegen/cuda/ir_builder.cpp index ce3e17d74d2..17a4e59cfb6 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -1,35 +1,97 @@ -#include +#include +#include +#include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { -namespace kir { + +//! Clone an IR node, forwarding the arguments to the IrCloner constructor. +template +T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { + TORCH_INTERNAL_ASSERT( + ir_cloner != nullptr, + "Cannot use create when a cloner object is set. 
Use clone."); + + TORCH_INTERNAL_ASSERT( + ir_cloner->container() != nullptr, + "Cloner doesn't have a valid container to store cloned object."); + + T* dest = new T(src, ir_cloner); + const Statement* src_stmt = dynamic_cast(src); + Statement* dest_stmt = dynamic_cast(dest); + + auto dest_container = ir_cloner->container(); + auto src_container = src_stmt->container(); + + dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + + if (src_container != dest_container) { + dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); + } + + ir_cloner->registerClone(src_stmt, dest_stmt); + + return dest; +} + +#define IR_BUILDER_INSTANTIATE(T) \ + template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner); + +// Vals +IR_BUILDER_INSTANTIATE(IterDomain) +IR_BUILDER_INSTANTIATE(TensorDomain) +IR_BUILDER_INSTANTIATE(TensorView) +IR_BUILDER_INSTANTIATE(Bool) +IR_BUILDER_INSTANTIATE(Double) +IR_BUILDER_INSTANTIATE(Int) +IR_BUILDER_INSTANTIATE(NamedScalar) + +// Exprs +IR_BUILDER_INSTANTIATE(Split) +IR_BUILDER_INSTANTIATE(Merge) +IR_BUILDER_INSTANTIATE(TransposeOp) +IR_BUILDER_INSTANTIATE(ShiftOp) +IR_BUILDER_INSTANTIATE(GatherOp) +IR_BUILDER_INSTANTIATE(ViewOp) +IR_BUILDER_INSTANTIATE(UnaryOp) +IR_BUILDER_INSTANTIATE(BinaryOp) +IR_BUILDER_INSTANTIATE(TernaryOp) +IR_BUILDER_INSTANTIATE(ReductionOp) +IR_BUILDER_INSTANTIATE(WelfordOp) +IR_BUILDER_INSTANTIATE(BroadcastOp) Val* IrBuilder::newResult(DataType dtype) { switch (dtype) { case DataType::Bool: - return create(c10::nullopt); + return IrBuilder::create(c10::nullopt); case DataType::Double: - return create(c10::nullopt); + return IrBuilder::create(c10::nullopt); case DataType::Int: - return create(c10::nullopt); + return IrBuilder::create(c10::nullopt); default: TORCH_CHECK(false, "Unexpected data type"); } } Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { - TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types"); + TORCH_CHECK( + lhs->dtype() == rhs->dtype(), + "Incompatible operand types: ", + lhs->dtype(), + " and ", + rhs->dtype()); auto result = newResult(lhs->dtype()); - create(op_type, result, lhs, rhs); + IrBuilder::create(op_type, result, lhs, rhs); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return result; } Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { - auto result = create(c10::nullopt); - create(op_type, result, lhs, rhs); + auto result = IrBuilder::create(c10::nullopt); + IrBuilder::create(op_type, result, lhs, rhs); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return result; } @@ -37,37 +99,37 @@ Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) { TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types"); auto result = newResult(lhs->dtype()); - create(TernaryOpType::Where, result, pred, lhs, rhs); + IrBuilder::create(TernaryOpType::Where, result, pred, lhs, rhs); return result; } Val* IrBuilder::negExpr(Val* val) { auto result = newResult(val->dtype()); - create(UnaryOpType::Neg, result, val); + IrBuilder::create(UnaryOpType::Neg, result, val); return result; } Val* IrBuilder::notExpr(Val* val) { auto result = newResult(val->dtype()); - create(UnaryOpType::Not, result, val); + IrBuilder::create(UnaryOpType::Not, result, val); return result; } Val* IrBuilder::setExpr(Val* val) { auto result = newResult(val->dtype()); - create(UnaryOpType::Set, result, val); + IrBuilder::create(UnaryOpType::Set, 
result, val); return result; } Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) { - auto result = create(name, val->dtype()); - create(UnaryOpType::Set, result, val); + auto result = IrBuilder::create(name, val->dtype()); + IrBuilder::create(UnaryOpType::Set, result, val); return result; } Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) { - auto result = create(name, DataType::Int); - create(UnaryOpType::Address, result, val); + auto result = IrBuilder::create(name, DataType::Int); + IrBuilder::create(UnaryOpType::Address, result, val); return result; } @@ -127,45 +189,10 @@ Val* IrBuilder::minExpr(Val* lhs, Val* rhs) { return newArithmeticExpr(BinaryOpType::Min, lhs, rhs); } -Int* IrBuilder::zeroVal() { - if (zero_ == nullptr) { - zero_ = create(0); - } - return zero_; -} - -Int* IrBuilder::oneVal() { - if (one_ == nullptr) { - one_ = create(1); - } - return one_; -} - -Bool* IrBuilder::falseVal() { - if (false_ == nullptr) { - false_ = create(false); - } - return false_; -} - -Bool* IrBuilder::trueVal() { - if (true_ == nullptr) { - true_ = create(true); - } - return true_; -} - -NamedScalar* IrBuilder::magicZeroVal() { - if (magic_zero_ == nullptr) { - magic_zero_ = create(kMagicZeroName, DataType::Int); - } - return magic_zero_; -} - Val* SimplifyingIrBuilder::negExpr(Val* val) { - if (auto int_val = dynamic_cast(val)) { + if (auto int_val = dynamic_cast(val)) { if (int_val->isConst()) { - return create(-int_val->value().value()); + return IrBuilder::create(-int_val->value().value()); } } return IrBuilder::negExpr(val); @@ -175,9 +202,9 @@ Val* SimplifyingIrBuilder::notExpr(Val* val) { if (auto bool_val = dynamic_cast(val)) { if (bool_val->isConst()) { if (bool_val->value().value()) { - return falseVal(); + return FusionGuard::getCurFusion()->falseVal(); } else { - return trueVal(); + return FusionGuard::getCurFusion()->trueVal(); } } } @@ -188,13 +215,13 @@ Val* SimplifyingIrBuilder::addExpr(Int* lhs, Int::ScalarType rhs) { if (rhs == 0) { return lhs; } else if (lhs == nullptr) { - return IrBuilder::create(rhs); + return IrBuilder::IrBuilder::create(rhs); } else if (lhs->isConst()) { - return IrBuilder::create(lhs->value().value() + rhs); + return IrBuilder::IrBuilder::create(lhs->value().value() + rhs); } else if (rhs > 0) { - return IrBuilder::addExpr(lhs, IrBuilder::create(rhs)); + return IrBuilder::addExpr(lhs, IrBuilder::IrBuilder::create(rhs)); } else { - return IrBuilder::subExpr(lhs, IrBuilder::create(-rhs)); + return IrBuilder::subExpr(lhs, IrBuilder::IrBuilder::create(-rhs)); } } @@ -228,6 +255,15 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } } +Val* SimplifyingIrBuilder::addExpr(Val* lhs, Int::ScalarType rhs) { + auto lhs_int = dynamic_cast(lhs); + if (lhs_int != nullptr) { + return addExpr(lhs_int, rhs); + } else { + return addExpr(lhs, IrBuilder::create(rhs)); + } +} + Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) { return addExpr(lhs, negExpr(rhs)); } @@ -257,9 +293,9 @@ Val* SimplifyingIrBuilder::andExpr(Val* lhs, Val* rhs) { } if (lhs_definitely_true && rhs_definitely_true) { - return trueVal(); + return FusionGuard::getCurFusion()->trueVal(); } else if (lhs_definitely_false || rhs_definitely_false) { - return falseVal(); + return FusionGuard::getCurFusion()->falseVal(); } else if (lhs_definitely_true) { return rhs; } else if (rhs_definitely_true) { @@ -269,7 +305,65 @@ Val* SimplifyingIrBuilder::andExpr(Val* lhs, Val* rhs) { return IrBuilder::andExpr(lhs, rhs); } -} // namespace kir +namespace { + 
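The simplifications above all follow one pattern: fold when both operands are compile-time constants, short-circuit identities such as adding zero (or a definitely-true/false operand for andExpr), and only fall back to building a real expression node otherwise. Below is a self-contained sketch of that pattern using hypothetical SymVal/makeConst/makeSymbol helpers rather than the real Int/NamedScalar classes.

// Constant-folding addition in the style of SimplifyingIrBuilder::addExpr.
#include <iostream>
#include <optional>
#include <string>
#include <utility>

struct SymVal {
  std::optional<long> constant; // engaged for literal values
  std::string name;             // used for symbolic values
  bool isConst() const {
    return constant.has_value();
  }
};

SymVal makeConst(long v) {
  return SymVal{v, ""};
}

SymVal makeSymbol(std::string n) {
  return SymVal{std::nullopt, std::move(n)};
}

std::string repr(const SymVal& v) {
  return v.isConst() ? std::to_string(*v.constant) : v.name;
}

SymVal simplifyingAdd(const SymVal& lhs, const SymVal& rhs) {
  if (lhs.isConst() && rhs.isConst()) {
    return makeConst(*lhs.constant + *rhs.constant); // fold eagerly
  }
  if (lhs.isConst() && *lhs.constant == 0) {
    return rhs; // 0 + x -> x
  }
  if (rhs.isConst() && *rhs.constant == 0) {
    return lhs; // x + 0 -> x
  }
  // Otherwise build a symbolic node; the real builder allocates a BinaryOp
  // and registers it with the container instead of concatenating strings.
  return makeSymbol("(" + repr(lhs) + " + " + repr(rhs) + ")");
}

int main() {
  std::cout << *simplifyingAdd(makeConst(1), makeConst(2)).constant << "\n"; // 3
  std::cout << simplifyingAdd(makeConst(0), makeSymbol("foo")).name << "\n"; // foo
  return 0;
}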
+template +Val* minOrMaxExpr( + Int* lhs, + Int* rhs, + IrBuilderFunc ir_builder_func, + IntFunc int_func) { + if (rhs == nullptr) { + return lhs; + } else if (lhs == nullptr) { + return rhs; + } else if (lhs->isConst() && rhs->isConst()) { + return IrBuilder::create( + int_func(lhs->value().value(), rhs->value().value())); + } else { + return ir_builder_func(lhs, rhs); + } +} + +template +Val* minOrMaxExpr( + Val* lhs, + Val* rhs, + IrBuilderFunc ir_builder_func, + IntFunc int_func) { + TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr); + if (lhs == nullptr) { + return rhs; + } else if (rhs == nullptr || lhs == rhs) { + return lhs; + } + auto lhs_int = dynamic_cast(lhs); + auto rhs_int = dynamic_cast(rhs); + if (lhs_int != nullptr && rhs_int != nullptr) { + return minOrMaxExpr(lhs_int, rhs_int, ir_builder_func, int_func); + } else { + return ir_builder_func(lhs, rhs); + } +} + +} // namespace + +Val* SimplifyingIrBuilder::maxExpr(Val* lhs, Val* rhs) { + return minOrMaxExpr( + lhs, + rhs, + [](Val* lhs, Val* rhs) { return IrBuilder::maxExpr(lhs, rhs); }, + [](int64_t lhs, int64_t rhs) { return std::max(lhs, rhs); }); +} + +Val* SimplifyingIrBuilder::minExpr(Val* lhs, Val* rhs) { + return minOrMaxExpr( + lhs, + rhs, + [](Val* lhs, Val* rhs) { return IrBuilder::minExpr(lhs, rhs); }, + [](int64_t lhs, int64_t rhs) { return std::min(lhs, rhs); }); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.h b/torch/csrc/jit/codegen/cuda/ir_builder.h new file mode 100644 index 00000000000..5087f2832a9 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_builder.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace kir { +class Kernel; +} + +class IrCloner; + +// Passkey for builder to register properties with statements, and to call +// functions in IrContainer +class TORCH_CUDA_CU_API IrBuilderPasskey { + friend class IrBuilder; + + public: + // TODO: Collapse ir_container and Kernel once Kernel inherits from + // IrContainer + IrContainer* const ir_container_ = nullptr; + + private: + explicit IrBuilderPasskey(IrContainer* ir_container); +}; + +//! IR builder interface +class TORCH_CUDA_CU_API IrBuilder { + public: + //! Allocate a new IR node, forwarding the arguments to the appropriate + //! constructor and registering with the container + template + static T* create(Args&&... args) { + auto container = FusionGuard::getCurFusion(); + // return create(container, std::forward(args)...); + TORCH_INTERNAL_ASSERT( + container != nullptr, "Need an active container to build IR."); + T* node = new T(IrBuilderPasskey(container), std::forward(args)...); + + container->registerStmt(IrBuilderPasskey(container), node); + + return node; + } + + //! Allocate a new IR node, forwarding the arguments to the appropriate + //! constructor and registering with the container + template + static T* create(IrContainer* container, Args&&... args) { + TORCH_INTERNAL_ASSERT( + container != nullptr, "Need an active container to build IR."); + T* node = new T(IrBuilderPasskey(container), std::forward(args)...); + + container->registerStmt(IrBuilderPasskey(container), node); + + return node; + } + + //! Clone an IR node, forwarding the arguments to the IrCloner constructor. + //! Register clones with IrCloner's target container. 
+ template + static T* clone(const T* src, IrCloner* ir_cloner); + + // Unary operations + static Val* negExpr(Val* val); + static Val* notExpr(Val* val); + static Val* setExpr(Val* val); + static Val* setExprNamedScalar(const std::string& name, Val* val); + static Val* addressExprNamedScalar(const std::string& name, Val* val); + + // Binary operations + static Val* andExpr(Val* lhs, Val* rhs); + static Val* eqExpr(Val* lhs, Val* rhs); + static Val* gtExpr(Val* lhs, Val* rhs); + static Val* ltExpr(Val* lhs, Val* rhs); + static Val* leExpr(Val* lhs, Val* rhs); + static Val* geExpr(Val* lhs, Val* rhs); + static Val* addExpr(Val* lhs, Val* rhs); + static Val* subExpr(Val* lhs, Val* rhs); + static Val* mulExpr(Val* lhs, Val* rhs); + static Val* divExpr(Val* lhs, Val* rhs); + static Val* ceilDivExpr(Val* lhs, Val* rhs); + static Val* modExpr(Val* lhs, Val* rhs); + static Val* maxExpr(Val* lhs, Val* rhs); + static Val* minExpr(Val* lhs, Val* rhs); + + // Ternary operations + static Val* whereExpr(Val* pred, Val* lhs, Val* rhs); + + private: + static Val* newResult(DataType dtype); + static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs); + static Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs); +}; + +//! A wrapper builder with static expression simplification +//! +//! Example: +//! - addExpr(new Int(1), new Int(2)) -> Int(3) +//! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo") +//! +//! Designed to be used to simplify predicate and index expressions in +//! generated code. Also, the shift validation may fail without +//! this simplification. +class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { + public: + static Val* negExpr(Val* val); + static Val* notExpr(Val* val); + + static Val* addExpr(Int* lhs, Int::ScalarType rhs); + static Val* addExpr(Val* lhs, Int::ScalarType rhs); + static Val* addExpr(Int* lhs, Int* rhs); + static Val* addExpr(Val* lhs, Val* rhs); + static Val* subExpr(Val* lhs, Val* rhs); + static Val* andExpr(Val* lhs, Val* rhs); + static Val* maxExpr(Val* lhs, Val* rhs); + static Val* minExpr(Val* lhs, Val* rhs); +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 7e5a9cfa8bc..8a1717e8d05 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -2,12 +2,15 @@ #include #include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { +IrCloner::IrCloner(IrContainer* container) : ir_container_(container) {} + Statement* IrCloner::clone(const Statement* statement) { if (statement == nullptr) { return nullptr; @@ -30,7 +33,6 @@ Statement* IrCloner::clone(const Statement* statement) { // that something went horribly wrong. 
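At its core, IrBuilder::create above is a variadic factory: it perfect-forwards the constructor arguments (prefixed with a passkey) and hands the new node to the container, which takes ownership and assigns a name. The sketch below shows that shape with hypothetical Builder/Container/IntLit types; the passkey and naming are omitted so it stays short and compiles on its own.

// A variadic create() factory that forwards constructor arguments and
// registers the new node with an owning container.
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct Container;

struct Node {
  explicit Node(Container& c) : container(&c) {}
  virtual ~Node() = default;
  Container* container;
};

struct Container {
  std::vector<std::unique_ptr<Node>> nodes;

  void registerNode(Node* n) {
    nodes.emplace_back(n); // container takes ownership
  }
};

struct IntLit : Node {
  IntLit(Container& c, long v) : Node(c), value(v) {}
  long value;
};

struct Builder {
  // Forward any constructor arguments to T, then register the result.
  template <typename T, typename... Args>
  static T* create(Container& container, Args&&... args) {
    T* node = new T(container, std::forward<Args>(args)...);
    container.registerNode(node);
    return node;
  }
};

int main() {
  Container fusion;
  IntLit* zero = Builder::create<IntLit>(fusion, 0);
  std::cout << zero->value << " nodes=" << fusion.nodes.size() << "\n"; // 0 nodes=1
  return 0;
}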
TORCH_INTERNAL_ASSERT(new_node != nullptr); TORCH_INTERNAL_ASSERT(clones_map_[statement] == new_node); - TORCH_INTERNAL_ASSERT(new_node->fusion() == fusion_); return new_node; } @@ -39,7 +41,6 @@ Statement* IrCloner::clone(const Statement* statement) { void IrCloner::registerClone(const Statement* src, Statement* clone) { TORCH_CHECK(src != nullptr); TORCH_CHECK(clone != nullptr); - TORCH_CHECK(clone->fusion() == fusion_); TORCH_CHECK(clones_map_.insert({src, clone}).second); } @@ -56,79 +57,79 @@ void IrCloner::handle(const Expr* e) { } void IrCloner::handle(const TensorDomain* td) { - clone_ = new TensorDomain(td, this); + clone_ = IrBuilder::clone(td, this); } void IrCloner::handle(const IterDomain* id) { - clone_ = new IterDomain(id, this); + clone_ = IrBuilder::clone(id, this); } void IrCloner::handle(const Bool* b) { - clone_ = new Bool(b, this); + clone_ = IrBuilder::clone(b, this); } void IrCloner::handle(const Double* d) { - clone_ = new Double(d, this); + clone_ = IrBuilder::clone(d, this); } void IrCloner::handle(const Int* i) { - clone_ = new Int(i, this); + clone_ = IrBuilder::clone(i, this); } void IrCloner::handle(const NamedScalar* named_scalar) { - clone_ = new NamedScalar(named_scalar, this); + clone_ = IrBuilder::clone(named_scalar, this); } void IrCloner::handle(const TensorView* tv) { - clone_ = new TensorView(tv, this); + clone_ = IrBuilder::clone(tv, this); } void IrCloner::handle(const UnaryOp* op) { - clone_ = new UnaryOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const BinaryOp* op) { - clone_ = new BinaryOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const TernaryOp* op) { - clone_ = new TernaryOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const BroadcastOp* op) { - clone_ = new BroadcastOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const ReductionOp* op) { - clone_ = new ReductionOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const WelfordOp* op) { - clone_ = new WelfordOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const TransposeOp* op) { - clone_ = new TransposeOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const ShiftOp* op) { - clone_ = new ShiftOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const GatherOp* op) { - clone_ = new GatherOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const ViewOp* op) { - clone_ = new ViewOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const Split* split) { - clone_ = new Split(split, this); + clone_ = IrBuilder::clone(split, this); } void IrCloner::handle(const Merge* merge) { - clone_ = new Merge(merge, this); + clone_ = IrBuilder::clone(merge, this); } TensorView* RecomputeTv::recompute(TensorView* tv) { @@ -141,7 +142,7 @@ TensorView* RecomputeTv::recompute(TensorView* tv) { "Cannot recompute buffers that are inputs of the fusion."); // Grab all the expressions used to generate the TensorView - auto exprs = ExprSort::getExprs(tv->fusion(), {tv}); + auto exprs = StmtSort::getExprs(tv->fusion(), {tv}, false); // Run the replicator RecomputeTv replicator(tv->fusion(), exprs); @@ -161,7 +162,7 @@ TensorView* RecomputeTv::recompute(TensorView* tv) { } RecomputeTv::RecomputeTv(Fusion* fusion, std::vector exprs) - : IrCloner(fusion) { + : IrCloner(fusion), fusion_(fusion) { // Add inputs to the clones map to prevent cloning them. 
for (const auto inp : fusion->inputs()) { clones_map_[inp] = inp; @@ -183,7 +184,7 @@ void RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. auto exprs = - ExprSort::getExprs(fusion(), {td->domain().begin(), td->domain().end()}); + StmtSort::getExprs(fusion_, {td->domain().begin(), td->domain().end()}); for (auto expr : exprs) { IrCloner::handle(expr); diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index ac83d9edb09..1755b9e9563 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include #include #include @@ -11,7 +12,7 @@ namespace jit { namespace fuser { namespace cuda { -class Fusion; +class IrContainer; //! Clones nodes from an exiting Fusion //! @@ -21,10 +22,11 @@ class Fusion; //! class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { friend class Statement; + friend class IrBuilder; public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit IrCloner(Fusion* new_fusion) : fusion_(new_fusion) {} + explicit IrCloner(IrContainer* container); Statement* clone(const Statement* statement); @@ -45,8 +47,8 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { return copy; } - Fusion* fusion() const { - return fusion_; + IrContainer* container() const { + return ir_container_; } protected: @@ -86,12 +88,15 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { private: // The destination Fusion container - Fusion* fusion_ = nullptr; + IrContainer* ir_container_ = nullptr; // The dispatch interface doesn't allow returning values from // individual `handle()` methods, so they are storing the // result here Statement* clone_ = nullptr; + + // Builder to make all the new nodes + IrBuilder builder_; }; // Replicates all expressions used to generate the provided TensorView. 
Does not @@ -105,7 +110,9 @@ class RecomputeTv : private IrCloner { private: RecomputeTv(Fusion* fusion, std::vector exprs); - void handle(const TensorDomain*) override; + void handle(const TensorDomain*) final; + + Fusion* fusion_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/ir_container.cpp b/torch/csrc/jit/codegen/cuda/ir_container.cpp new file mode 100644 index 00000000000..e84418eb973 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_container.cpp @@ -0,0 +1,279 @@ +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +void swap(IrContainer& a, IrContainer& b) noexcept { + FUSER_PERF_SCOPE("Fusion swap"); + + using std::swap; + + // Swap the content + swap(a.vals_up_, b.vals_up_); + swap(a.vals_, b.vals_); + + swap(a.exprs_up_, b.exprs_up_); + swap(a.exprs_, b.exprs_); + + swap(a.raw_ptrs_, b.raw_ptrs_); + + swap(a.val_type_name_map_, b.val_type_name_map_); + swap(a.expr_name_counter_, b.expr_name_counter_); + + // Fixup the Statement::fusion_ links for a + for (auto val : a.vals_) { + val->ir_container_ = &a; + } + for (auto expr : a.exprs_) { + expr->ir_container_ = &a; + } + + // Fixup the Statement::fusion_ links for b + for (auto val : b.vals_) { + val->ir_container_ = &a; + } + for (auto expr : b.exprs_) { + expr->ir_container_ = &a; + } +} + +IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { + to->clear(); + IrCloner ir_cloner(to); + + for (auto val : from->vals_) { + to->vals_.insert(ir_cloner.clone(val)); + } + + for (auto expr : from->exprs_) { + to->exprs_.insert(ir_cloner.clone(expr)); + } + + to->val_type_name_map_ = from->val_type_name_map_; + to->expr_name_counter_ = from->expr_name_counter_; + + return ir_cloner; +} + +IrContainer::IrContainer() = default; + +IrContainer::IrContainer(const IrContainer& other) { + FUSER_PERF_SCOPE("IrContainer copy"); + IrContainer::copy(&other, this); +} + +IrContainer::IrContainer(IrContainer&& other) noexcept { + FUSER_PERF_SCOPE("IrContainer move"); + swap(*this, other); +} + +IrContainer& IrContainer::operator=(const IrContainer& other) { + FUSER_PERF_SCOPE("IrContainer copy assign"); + IrContainer copy(other); + clear(); + swap(*this, copy); + return *this; +} + +IrContainer& IrContainer::operator=(IrContainer&& other) noexcept { + FUSER_PERF_SCOPE("IrContainer move assign"); + clear(); + swap(*this, other); + return *this; +} + +IrContainer::~IrContainer() { + clear(); +} + +//! Register the Statement with this container +void IrContainer::registerStmt(IrBuilderPasskey, Statement* stmt) { + if (stmt->isVal()) { + registerVal(stmt->asVal()); + } else { + registerExpr(stmt->asExpr()); + } +} + +//! Register the Val with this container +void IrContainer::registerVal(IrBuilderPasskey, Val* val) { + registerVal(val); +} + +//! Register expr with this container. 
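IrContainer's assignment operators above use the copy-and-swap idiom: copy assignment builds a complete copy first and then swaps it in, move assignment clears and swaps, and a single friend swap() moves the state across. A minimal stand-alone illustration with a hypothetical Buffer class:

// Copy-and-swap in miniature: both assignment operators are written in terms
// of one noexcept swap().
#include <utility>
#include <vector>

class Buffer {
 public:
  Buffer() = default;
  Buffer(const Buffer& other) : data_(other.data_) {}
  Buffer(Buffer&& other) noexcept {
    swap(*this, other);
  }

  Buffer& operator=(const Buffer& other) {
    Buffer copy(other); // may throw, but *this is still untouched
    swap(*this, copy);  // commit by swapping with the finished copy
    return *this;
  }

  Buffer& operator=(Buffer&& other) noexcept {
    data_.clear();
    swap(*this, other);
    return *this;
  }

  friend void swap(Buffer& a, Buffer& b) noexcept {
    using std::swap;
    swap(a.data_, b.data_);
  }

  void push(int v) {
    data_.push_back(v);
  }

  std::size_t size() const {
    return data_.size();
  }

 private:
  std::vector<int> data_;
};

int main() {
  Buffer a;
  a.push(1);
  Buffer b;
  b = a;                   // copy-and-swap
  Buffer c = std::move(a); // move leaves `a` empty
  return (b.size() == 1 && c.size() == 1) ? 0 : 1;
}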
+void IrContainer::registerExpr(IrBuilderPasskey, Expr* expr) { + registerExpr(expr); +} + +void IrContainer::registerExpr(ExprPasskey, Expr* expr) { + registerExpr(expr); +} + +void IrContainer::removeExpr(Expr* expr) { + TORCH_INTERNAL_ASSERT( + exprs_.find(expr) != exprs_.end(), + "Wanted to remove an expression but it doesn't exist in this container."); + auto expr_in_deque = std::find_if( + exprs_up_.begin(), + exprs_up_.end(), + [expr](std::unique_ptr& expr_up) { return expr_up.get() == expr; }); + + TORCH_INTERNAL_ASSERT( + expr_in_deque != exprs_up_.end(), + "Wanted to remove an expression but its unique ptr is missing."); + + exprs_.erase(expr); + exprs_up_.erase(expr_in_deque); + raw_ptrs_.erase((void*)expr); +} + +//! Completely remove val from the fusion, break all dependencies associated +//! with it +void IrContainer::removeVal(Val* val) { + // Don't remove shortcuts + if (val == true_val_.get() || val == false_val_.get() || + val == one_val_.get() || val == zero_val_.get() || + val == magic_zero_val_.get()) { + return; + } + + TORCH_INTERNAL_ASSERT( + vals_.find(val) != vals_.end(), + "Wanted to remove a value but it doesn't exist in this container."); + auto val_in_deque = std::find_if( + vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr& val_up) { + return val_up.get() == val; + }); + + TORCH_INTERNAL_ASSERT( + val_in_deque != vals_up_.end(), + "Wanted to remove a value but its unique ptr is missing."); + + vals_.erase(val); + vals_up_.erase(val_in_deque); + raw_ptrs_.erase((void*)val); +} + +//! Register the Val with this container +void IrContainer::registerVal(Val* val) { + if (inContainer(val)) { + return; + } + + vals_up_.emplace_back(std::unique_ptr(val)); + vals_.emplace(vals_up_.back().get()); + val->setName(IrContainerPasskey(), getValName(vals_up_.back()->vtype())); + raw_ptrs_.emplace((void*)vals_up_.back().get()); +} + +//! Register expr with this container. 
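removeVal, removeExpr, and registerVal above rely on a specific ownership layout: a deque of unique_ptr is the single owning store (and keeps insertion order), while sets of plain pointers give cheap membership tests. A small sketch of that layout with a hypothetical Node/ToyContainer pair:

// Owning deque of unique_ptr plus a raw-pointer set for O(1) membership
// checks, in the spirit of vals_up_/vals_/raw_ptrs_.
#include <cassert>
#include <deque>
#include <memory>
#include <unordered_set>

struct Node {
  int name = -1;
};

class ToyContainer {
 public:
  Node* registerNode(std::unique_ptr<Node> node) {
    node->name = next_name_++;
    nodes_up_.push_back(std::move(node));
    Node* raw = nodes_up_.back().get();
    nodes_.insert(raw);
    return raw;
  }

  bool inContainer(const Node* node) const {
    // The set stores non-const pointers, so a const query goes through
    // const_cast, much like the real inContainer does with raw_ptrs_.
    return nodes_.count(const_cast<Node*>(node)) != 0;
  }

  void removeNode(Node* node) {
    assert(inContainer(node));
    nodes_.erase(node);
    for (auto it = nodes_up_.begin(); it != nodes_up_.end(); ++it) {
      if (it->get() == node) {
        nodes_up_.erase(it); // unique_ptr destructor frees the node
        break;
      }
    }
  }

 private:
  std::deque<std::unique_ptr<Node>> nodes_up_; // owning storage
  std::unordered_set<Node*> nodes_;            // fast membership lookup
  int next_name_ = 0;
};

int main() {
  ToyContainer c;
  Node* n = c.registerNode(std::make_unique<Node>());
  assert(c.inContainer(n));
  c.removeNode(n);
  return 0;
}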
+void IrContainer::registerExpr(Expr* expr) { + if (inContainer(expr)) { + return; + } + exprs_up_.emplace_back(std::unique_ptr(expr)); + exprs_.emplace(exprs_up_.back().get()); + expr->setName(IrContainerPasskey(), getExprName()); + raw_ptrs_.emplace((void*)exprs_up_.back().get()); +} + +void IrContainer::clear() noexcept { + FUSER_PERF_SCOPE("IrContainer clear"); + vals_.clear(); + vals_up_.clear(); + exprs_.clear(); + exprs_up_.clear(); + raw_ptrs_.clear(); + + val_type_name_map_.clear(); + expr_name_counter_ = 0; +} + +bool IrContainer::inContainer(const Statement* stmt) const { + const void* const_void = (const void*)(stmt); + void* nonconst_void = const_cast(const_void); // NOLINT + if (raw_ptrs_.find(nonconst_void) == raw_ptrs_.end()) { + return false; + } + + TORCH_INTERNAL_ASSERT( + stmt->container() == this, + "Container claims to own stmt, but stmt disagrees."); + + Statement* nonconst_stmt = const_cast(stmt); // NOLINT + if (stmt->isExpr()) { + TORCH_INTERNAL_ASSERT( + exprs_.find(nonconst_stmt->as()) != exprs_.end(), + "Somehow container claims to and not to own an Expr."); + } + if (stmt->isVal()) { + TORCH_INTERNAL_ASSERT( + vals_.find(nonconst_stmt->as()) != vals_.end(), + "Somehow container claims to and not to own an Val."); + } + + return true; +} + +// Shortcuts for frequently used vals +Int* IrContainer::zeroVal() { + if (!zero_val_) { + auto zero_val = IrBuilder::create(this, 0); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == zero_val); + zero_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return zero_val_.get(); +} + +Int* IrContainer::oneVal() { + if (!one_val_) { + auto one_val = IrBuilder::create(this, 1); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == one_val); + one_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return one_val_.get(); +} + +Bool* IrContainer::falseVal() { + if (!false_val_) { + auto false_val = IrBuilder::create(this, false); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == false_val); + false_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return false_val_.get(); +} + +Bool* IrContainer::trueVal() { + if (!true_val_) { + auto true_val = IrBuilder::create(this, true); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == true_val); + true_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return true_val_.get(); +} + +NamedScalar* IrContainer::magicZeroVal() { + if (!magic_zero_val_) { + auto magic_zero = + IrBuilder::create(kMagicZeroName, DataType::Int); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == magic_zero); + magic_zero_val_ = std::unique_ptr( + vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return magic_zero_val_.get(); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_container.h b/torch/csrc/jit/codegen/cuda/ir_container.h new file mode 100644 index 00000000000..fb1aaeaf383 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_container.h @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class IrBuilderPasskey; +class ExprPasskey; +class OptOutMutator; + +class Int; +class Bool; +class NamedScalar; + +// Passkey for container to register names with statements +class IrContainerPasskey { + friend class IrContainer; + + private: + explicit IrContainerPasskey() {} +}; 
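IrContainerPasskey just above, like IrBuilderPasskey and ExprPasskey earlier, is an instance of the passkey idiom: a key type whose constructor is private to one friend, so a public method that takes the key by value can in practice only be called by that friend. A compilable toy version with hypothetical Registrar/Record names:

// The passkey idiom: setName-style methods stay public, but only the class
// that can mint the key is able to call them.
#include <iostream>

class Registrar; // the only class allowed to mint keys

class RegistrarPasskey {
  friend class Registrar;

 private:
  RegistrarPasskey() = default; // only Registrar can construct a key
};

class Record {
 public:
  // Public in the usual sense, but effectively callable only by Registrar,
  // because nobody else can produce a RegistrarPasskey argument.
  void setName(RegistrarPasskey, int name) {
    name_ = name;
  }

  int name() const {
    return name_;
  }

 private:
  int name_ = -1;
};

class Registrar {
 public:
  void registerRecord(Record& r) {
    r.setName(RegistrarPasskey(), next_name_++);
  }

 private:
  int next_name_ = 0;
};

int main() {
  Record r;
  Registrar reg;
  reg.registerRecord(r);
  // r.setName(RegistrarPasskey(), 42);  // would not compile: ctor is private
  std::cout << r.name() << "\n"; // 0
  return 0;
}

The same trick is what lets Statement::setName stay a public member while still being callable only with a container or builder passkey.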
+ +class TORCH_CUDA_CU_API IrContainer : public PolymorphicBase { + public: + IrContainer(); + + IrContainer(const IrContainer& other); + IrContainer(IrContainer&& other) noexcept; + + IrContainer& operator=(const IrContainer& other); + IrContainer& operator=(IrContainer&& other) noexcept; + + virtual ~IrContainer(); + + bool inContainer(const Statement* stmt) const; + + void assertInContainer(const Statement* stmt, const std::string& msg) const { + TORCH_CHECK( + inContainer(stmt), msg, " it was not found in the active container."); + } + + //! Return in insertion order + const std::deque deterministic_vals() const noexcept { + std::deque vals_deque; + std::transform( + vals_up_.begin(), + vals_up_.end(), + std::back_inserter(vals_deque), + [](const std::unique_ptr& val_up) { return val_up.get(); }); + return vals_deque; + } + + //! Register the Statement with this container + virtual void registerStmt(IrBuilderPasskey, Statement* stmt); + + //! Register the Val with this container + virtual void registerVal(IrBuilderPasskey, Val* val); + + //! Register expr with this container. + virtual void registerExpr(IrBuilderPasskey, Expr* expr); + + //! Allow expr's to register themselves with a container, this is only used + //! for broadcastOp so it can register itself in its constructor so root maps + //! can be built. + virtual void registerExpr(ExprPasskey, Expr* expr); + + //! Return the set of Exprs registered with this fusion. Warning: This will + //! return exprs outside inputs/outputs, so can be unsafe for use with + //! segmented fusions. + const std::unordered_set& unordered_exprs() const noexcept { + return exprs_; + } + + //! Return the set of Vals registered with this fusion + const std::unordered_set& vals() const noexcept { + return vals_; + } + + // Shortcuts for frequently used vals + Int* zeroVal(); + Int* oneVal(); + Bool* falseVal(); + Bool* trueVal(); + NamedScalar* magicZeroVal(); + + protected: + static IrCloner copy(const IrContainer* from, IrContainer* to); + + friend void swap(IrContainer& a, IrContainer& b) noexcept; + + // Let mutator remove Exprs. + friend OptOutMutator; + + virtual void removeExpr(Expr* expr); + + //! Completely remove val from the fusion, break all dependencies associated + //! with it + virtual void removeVal(Val* val); + + //! Register the Val with this container + virtual void registerVal(Val* val); + + //! Register expr with this container. + virtual void registerExpr(Expr* expr); + + StmtNameType getValName(ValType vtype) { + if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { + val_type_name_map_[vtype] = 0; + } + return val_type_name_map_[vtype]++; + } + + StmtNameType getExprName() { + return expr_name_counter_++; + } + + void clear() noexcept; + + // Deque of unique pointer is the memory owning data structure + std::deque> vals_up_; + + // A convenient set to return when we just need an unordered set to do + // something like check if a Val is in this container + std::unordered_set vals_; + + // Deque of unique pointer is the memory owning data structure + std::deque> exprs_up_; + + // A convenient set to return when we just need an unordered set to do + // something like check if an Expr is in this container + std::unordered_set exprs_; + + // Used to implement a generic "inContainer" that can be passed an invalid + // pointer. Specifically a pointer to a Statement owned by another container + // that has been freed. 
We can't check normally with the unordered_sets we + // already have because it would require a const_cast from a constant + // expr/val, or a dynamic cast from a Statement. + std::unordered_set raw_ptrs_; + + // Values names counters + std::unordered_map val_type_name_map_; + + // Expression names counter + StmtNameType expr_name_counter_ = 0; + + // Manually store some persistent, frequently used nodes. It's very + // challenging to do this anything but manually as detecting when a container + // may or may not have one of these vals is tricky. Specifically because if + // the container doesn't own it, it's hard to understand from the outside if + // the node may have been removed then re-registered. It could also be tricky + // to know when we're using a different container as in FusionCopy_test + // demonstrates deleting then creating containers can result in the same + // pointer for the container. + std::unique_ptr true_val_; + std::unique_ptr false_val_; + std::unique_ptr one_val_; + std::unique_ptr zero_val_; + std::unique_ptr magic_zero_val_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp index 5ca8d54aaa9..7511fbd4d6d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -303,13 +304,13 @@ void IrGraphGenerator::generateScheduleGraph() { // Maybe not the best way to handle the root domain, but should be okay addArc( tv, - new TensorDomain(tv->getRootDomain()), + IrBuilder::create(tv->getRootDomain()), "[style=dashed, color=green, arrowhead=none]"); if (tv->domain()->hasRFactor()) addArc( tv, - new TensorDomain(tv->domain()->getRFactorDomain()), + IrBuilder::create(tv->domain()->getRFactorDomain()), "[style=dashed, color=green, arrowhead=none]"); } } diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h index 1144d95eb15..f9b3adf703d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 02c319d3665..28478c64d91 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -19,6 +19,9 @@ namespace cuda { class WelfordResult; class ViewTransform; +class IrCloner; +class IrBuilderPasskey; + //! A Bool value //! //! This value can be a symbolic value (defined after the kernel @@ -26,17 +29,18 @@ class ViewTransform; //! 
class TORCH_CUDA_CU_API Bool : public Val { public: - Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {} + Bool(IrBuilderPasskey passkey); + + explicit Bool(IrBuilderPasskey passkey, bool value); - explicit Bool(bool value) - : Val(ValType::Scalar, DataType::Bool), maybe_value_{value} {} + explicit Bool(IrBuilderPasskey passkey, c10::optional value); Bool(const Bool* src, IrCloner* ir_cloner); bool isSymbolic() const { return !(maybe_value_.has_value()); } - bool isConst() const { + bool isConst() const final { return maybe_value_.has_value(); } c10::optional value() const { @@ -56,18 +60,18 @@ class TORCH_CUDA_CU_API Double : public Val { public: using ScalarType = double; - Double() - : Val(ValType::Scalar, DataType::Double), maybe_value_{c10::nullopt} {} + Double(IrBuilderPasskey passkey); + + explicit Double(IrBuilderPasskey passkey, ScalarType value); - explicit Double(ScalarType value) - : Val(ValType::Scalar, DataType::Double), maybe_value_{value} {} + explicit Double(IrBuilderPasskey passkey, c10::optional value); Double(const Double* src, IrCloner* ir_cloner); bool isSymbolic() const { return !(maybe_value_.has_value()); } - bool isConst() const { + bool isConst() const final { return maybe_value_.has_value(); } c10::optional value() const { @@ -86,17 +90,18 @@ class TORCH_CUDA_CU_API Int : public Val { public: using ScalarType = int64_t; - Int() : Val(ValType::Scalar, DataType::Int), maybe_value_{c10::nullopt} {} + Int(IrBuilderPasskey passkey); - explicit Int(ScalarType value) - : Val(ValType::Scalar, DataType::Int), maybe_value_{value} {} + explicit Int(IrBuilderPasskey passkey, ScalarType value); + + explicit Int(IrBuilderPasskey passkey, c10::optional value); Int(const Int* src, IrCloner* ir_cloner); bool isSymbolic() const { return !(maybe_value_.has_value()); } - bool isConst() const { + bool isConst() const final { return maybe_value_.has_value(); } c10::optional value() const { @@ -152,14 +157,18 @@ class TVDomainGuard; class TORCH_CUDA_CU_API TensorView : public Val { public: TensorView( + IrBuilderPasskey passkey, TensorDomain* domain, DataType dtype, MemoryType mtype = MemoryType::Local); - explicit TensorView(const std::shared_ptr& tensor_type); + explicit TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& tensor_type); - explicit TensorView(const std::shared_ptr& jit_value) - : TensorView(jit_value->type()->cast()) {} + explicit TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& jit_value); TensorView(const TensorView* src, IrCloner* ir_cloner); @@ -187,6 +196,16 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! trivial reductions bool hasAnyReduction() const; + //! Returns true if this tensor is zero dimensional, + //! i.e. a wrapped scalar or an empty placeholder. + bool isZeroDim() const { + return nDims() == 0; + } + + //! Returns true if this tensor does not contain + //! any value. + bool isEmptyTensor() const; + c10::optional getReductionAxis() const; const std::vector& getRootDomain() const; @@ -210,6 +229,24 @@ class TORCH_CUDA_CU_API TensorView : public Val { size_t nDims() const; + // sets cpu_scalar_ value, which is special handling for CPU based zero-dim + // tensors (i.e. CPU Tensors that only have one value). This is only used if + // on an input value, otherwise ignored. 
This is important as special handling + // because these "scalars" should be type promoted as a tensor, but we want to + // avoid explicit copying of the data, so we want to pass the data value as a + // standard kernel argument value. + void setCpuScalar(bool is_cpu_scalar); + + // returns cpu_scalar_ value, which is special handling for CPU based zero-dim + // tensors (i.e. CPU Tensors that only have one value). This is only used if + // on an input value, otherwise ignored. This is important as special handling + // because these "scalars" should be type promoted as a tensor, but we want to + // avoid explicit copying of the data, so we want to pass the data value as a + // standard kernel argument value. + bool isCpuScalar() const { + return cpu_scalar_; + } + // Returns the position that this tensor is produced at relative to its axes. unsigned int getComputeAtPosition() const { return compute_at_pos_; @@ -356,6 +393,13 @@ class TORCH_CUDA_CU_API TensorView : public Val { return axes_to_swizzle_; } + // Apply double buffering transformation + void doubleBuffer(); + + bool isDoubleBuffered() const { + return is_double_buffered_; + } + friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; @@ -393,6 +437,14 @@ class TORCH_CUDA_CU_API TensorView : public Val { MemoryType memory_type_ = MemoryType::Local; SwizzleType swizzle_type_ = SwizzleType::NoSwizzle; std::vector axes_to_swizzle_; + bool is_double_buffered_ = false; + // special handling for CPU based zero-dim tensors (i.e. CPU Tensors that only + // have one value). This is only used if on an input value, otherwise ignored. + // This is important as special handling because these "scalars" should be + // type promoted as a tensor, but we want to avoid explicit copying of the + // data, so we want to pass the data value as a standard kernel argument + // value. + bool cpu_scalar_ = false; }; //! A simple TensorView builder diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 8fd4475d2dd..bb494148be2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -1,10 +1,11 @@ #pragma once -#include +#include #include #include #include +#include //! Nodes in here should generally not be used by users. They should be behind //! the scenes and users shouldn't have to be aware of what they do to use the @@ -20,6 +21,8 @@ namespace fuser { namespace cuda { class ViewTransform; +class Scope; +class IrCloner; //! Returns true if both v1 and v2 are scalars, are the same type of scalars, //! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both @@ -34,7 +37,7 @@ bool areEqualScalars(Val* v1, Val* v2); //! 4) split/merge class TORCH_CUDA_CU_API UnaryOp : public Expr { public: - UnaryOp(UnaryOpType type, Val* out, Val* in); + UnaryOp(IrBuilderPasskey, UnaryOpType type, Val* out, Val* in); UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); @@ -63,7 +66,7 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { //! 2) LT (A < B) class TORCH_CUDA_CU_API BinaryOp : public Expr { public: - BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs); + BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); @@ -97,7 +100,11 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { //! \param out The output tensor //! \param in The input tensor //! 
\param is_broadcast_dims True when output dim is a new broadcast domain - BroadcastOp(Val* out, Val* in, std::vector is_broadcast_dims); + BroadcastOp( + IrBuilderPasskey, + Val* out, + Val* in, + std::vector is_broadcast_dims); BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); @@ -138,7 +145,12 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { //! non-reduction/non-broadcast dimensions. class TORCH_CUDA_CU_API ReductionOp : public Expr { public: - ReductionOp(BinaryOpType reduction_op_type, Val* init, Val* out, Val* in); + ReductionOp( + IrBuilderPasskey, + BinaryOpType reduction_op_type, + Val* init, + Val* out, + Val* in); ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); @@ -169,6 +181,7 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { class TORCH_CUDA_CU_API WelfordOp : public Expr { public: WelfordOp( + IrBuilderPasskey, Val* out_avg, Val* out_var, Val* out_N, @@ -189,10 +202,6 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { return in_avg_; } - Val* init() const { - return init_avg_; - } - bool sameAs(const Statement* const other) const override; // Welford Accessors @@ -255,7 +264,11 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { class TORCH_CUDA_CU_API TransposeOp : public Expr { public: - TransposeOp(TensorView* out, TensorView* in, std::vector new2old); + TransposeOp( + IrBuilderPasskey, + TensorView* out, + TensorView* in, + std::vector new2old); TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); @@ -279,7 +292,13 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { class TORCH_CUDA_CU_API TernaryOp : public Expr { public: - TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3); + TernaryOp( + IrBuilderPasskey, + TernaryOpType type, + Val* out, + Val* in1, + Val* in2, + Val* in3); TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); @@ -317,7 +336,12 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { //! \param out //! \param in //! \param offsets - ShiftOp(Val* out, Val* in, std::vector offsets, bool pad); + ShiftOp( + IrBuilderPasskey, + Val* out, + Val* in, + std::vector offsets, + std::vector pad_width); ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); @@ -336,8 +360,14 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { return offsets_; } - bool pad() const { - return pad_; + const std::vector& padWidth() const { + return pad_width_; + } + + bool hasPadding() const { + return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto p) { + return p > 0; + }); } bool sameAs(const Statement* other) const override; @@ -349,17 +379,18 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { //! offsets_. The sign of each value indicates the direction of //! shifting. const std::vector offsets_; - const bool pad_; + const std::vector pad_width_; }; //! Gather a window around each element. class TORCH_CUDA_CU_API GatherOp : public Expr { public: GatherOp( + IrBuilderPasskey, Val* out, Val* in, - std::vector window_shape, - std::vector> pad_width); + std::vector window_shape, + std::vector> pad_width); GatherOp(const GatherOp* src, IrCloner* ir_cloner); @@ -381,20 +412,26 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { return pad_width_; } + bool hasPadding() const { + return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto& p) { + return p[0] > 0 || p[1] > 0; + }); + } + bool sameAs(const Statement* other) const override; private: Val* const out_ = nullptr; Val* const in_ = nullptr; //! Shape of a window gathered for each element. - std::vector window_shape_; + std::vector window_shape_; //! 
The size of zero-padding of each axis. - std::vector> pad_width_; + std::vector> pad_width_; }; class TORCH_CUDA_CU_API ViewOp : public Expr { public: - ViewOp(TensorView* out, TensorView* in); + ViewOp(IrBuilderPasskey, TensorView* out, TensorView* in); ViewOp(const ViewOp* src, IrCloner* ir_cloner); @@ -422,6 +459,7 @@ class IndexReferenceReplay; class TORCH_CUDA_CU_API IterDomain : public Val { public: IterDomain( + IrBuilderPasskey, Val* start, Val* extent, ParallelType parallel_type = ParallelType::Serial, @@ -429,6 +467,7 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool is_rfactor_domain = false); IterDomain( + IrBuilderPasskey, Val* start, Val* extent, Val* stop_offset, @@ -441,20 +480,7 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool sameAs(const Statement* other) const override; // Returns a new IterDomain matching properties of this - // TODO: parallel_method->getParallelType - IterDomain* clone() const { - auto cloned = new IterDomain( - start(), - extent(), - stopOffset(), - getParallelType(), - getIterType(), - isRFactorProduct()); - - cloned->is_padded_dimension_ = is_padded_dimension_; - cloned->padded_to_size_ = padded_to_size_; - return cloned; - } + IterDomain* clone() const; //! Clone a vector domains static std::vector clone( @@ -631,6 +657,11 @@ class TORCH_CUDA_CU_API IterDomain : public Val { //! domain. std::pair stridedSplit(int factor); + // TODO: Remove + bool isSimple() const { + return definition() == nullptr; + } + protected: friend TensorDomain; friend ReplayTransformations; @@ -647,6 +678,10 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool is_rfactor_domain_ = false; bool is_padded_dimension_ = false; c10::optional padded_to_size_ = c10::nullopt; + + // TODO: Remove only used in kernel IR because IterDomains don't maintain + // definitions of split/merge. + bool is_simple_ = true; }; //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every @@ -666,15 +701,18 @@ class TORCH_CUDA_CU_API IterDomain : public Val { class TORCH_CUDA_CU_API TensorDomain : public Val { public: explicit TensorDomain( + IrBuilderPasskey, std::vector root_domain, std::vector contiguity = std::vector()); TensorDomain( + IrBuilderPasskey, std::vector root_domain, std::vector domain, std::vector contiguity = std::vector()); TensorDomain( + IrBuilderPasskey, std::vector root_domain, std::vector rfactor_domain, std::vector domain, @@ -718,6 +756,8 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { bool hasReduction() const; bool hasBlockReduction() const; bool hasGridReduction() const; + bool hasBlockBroadcast() const; + bool hasGridBroadcast() const; bool hasBroadcast() const; bool hasRFactor() const; bool hasVectorize() const; @@ -821,6 +861,7 @@ class TORCH_CUDA_CU_API Split : public Expr { // start_offset and stop_offset are distance from the left end and // right ends, respectively. Split( + IrBuilderPasskey, IterDomain* outer, IterDomain* inner, IterDomain* in, @@ -881,12 +922,13 @@ class TORCH_CUDA_CU_API Split : public Expr { //! dictate which will be traversed first (inner). Both IterDomains must be of //! the same iter or reduction type, as well as the same parallelization //! strategy if there is one -//! -//! \todo Should this be a unary op type? -//! 
class TORCH_CUDA_CU_API Merge : public Expr { public: - Merge(IterDomain* out, IterDomain* outer, IterDomain* inner); + Merge( + IrBuilderPasskey, + IterDomain* out, + IterDomain* outer, + IterDomain* inner); Merge(const Merge* src, IrCloner* ir_cloner); @@ -918,9 +960,7 @@ class TORCH_CUDA_CU_API Merge : public Expr { //! class TORCH_CUDA_CU_API NamedScalar : public Val { public: - // NOLINTNEXTLINE(modernize-pass-by-value) - NamedScalar(std::string name, DataType dtype) - : Val(ValType::NamedScalar, dtype), name_(name) {} + NamedScalar(IrBuilderPasskey passkey, std::string name, DataType dtype); NamedScalar(const NamedScalar* src, IrCloner* ir_cloner); @@ -931,9 +971,11 @@ class TORCH_CUDA_CU_API NamedScalar : public Val { bool sameAs(const Statement* other) const override; //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x) + //! WARNING: Only works with Fusion container at the moment static NamedScalar* getParallelDim(ParallelType p_type); //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x) + //! WARNING: Only works with Fusion container at the moment static NamedScalar* getParallelIndex(ParallelType p_type); //! Return the parallel type of this NamedScalar if it is an extent of a diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index a553c59fc2b..8c0e1022308 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -14,6 +15,23 @@ namespace jit { namespace fuser { namespace cuda { +namespace { +const char* boolLiteral(bool value) { + return value ? "true" : "false"; +} + +std::string varName(const Val* val) { + std::stringstream value_name; + if (val == nullptr) { + value_name << "$nullptr"; + } else { + value_name << val->name(); + } + return value_name.str(); +} + +} // namespace + // Make sure we can inline something, before we attempt to. 
static void checkInlineable(const Expr* expr) { for (auto input : expr->inputs()) { @@ -49,6 +67,70 @@ void IrPrinter::handle(Fusion* fusion) { } } +void IrPrinter::handle(const kir::Kernel* kernel) { + TORCH_CHECK(kernel != nullptr); + + // kernel declaration + os_ << "\nKERNEL ("; + for (auto in : kernel->inputs()) { + handle(in); + if (in != kernel->inputs().back()) { + os_ << ", "; + } + } + os_ << ") -> ("; + for (auto out : kernel->outputs()) { + handle(out); + if (out != kernel->outputs().back()) { + os_ << ", "; + } + } + os_ << ") :\n"; + + // kernel body + indent_size_++; + for (auto expr : kernel->topLevelExprs()) { + handle(expr); + } + indent_size_--; + os_ << "END.\n\n"; +} + +void IrPrinter::handle(kir::Kernel& kernel) { + handle(&kernel); +} + +void IrPrinter::handleScope(const kir::Scope& scope) { + // Save the uses of the parent scope + indent_size_++; + for (auto expr : scope.exprs()) { + handle(expr); + } + indent_size_--; +} + +void IrPrinter::handle(const IterDomain* id) { + os_ << id->getIterType(); + os_ << id->getParallelType(); + os_ << varName(id); + os_ << "{"; + if (!id->start()->isZeroInt()) { + print_inline(id->start()); + os_ << " : "; + } + if (id->stop() != id->extent()) { + print_inline(id->stop()); + os_ << " : "; + } + print_inline(id->extent()); + os_ << "}"; + if (id->isRFactorProduct()) + os_ << "rf"; + if (id->hasPaddingToMultipleOfWarp()) { + os_ << "_p"; + } +} + void IrPrinter::handle(const TensorDomain* td) { if (td->nDims() == 0) { os_ << "[ 0 ]"; @@ -65,9 +147,9 @@ void IrPrinter::handle(const TensorDomain* td) { void IrPrinter::handle(const TensorView* tv) { if (tv->nDims() == 0) { - os_ << typePrefix(tv->getDataType().value()) << tv->name(); + os_ << typePrefix(tv->getDataType().value()) << varName(tv); } else { - os_ << "T" << tv->name(); + os_ << "T" << varName(tv); switch (tv->getMemoryType()) { case MemoryType::Global: os_ << "_g"; @@ -94,28 +176,6 @@ void IrPrinter::handle(const TensorView* tv) { } } -void IrPrinter::handle(const IterDomain* id) { - os_ << id->getIterType(); - os_ << id->getParallelType(); - os_ << id->name(); - os_ << "{"; - if (!id->start()->isZeroInt()) { - print_inline(id->start()); - os_ << " : "; - } - if (id->stop() != id->extent()) { - print_inline(id->stop()); - os_ << " : "; - } - print_inline(id->extent()); - os_ << "}"; - if (id->isRFactorProduct()) - os_ << "rf"; - if (id->hasPaddingToMultipleOfWarp()) { - os_ << "_p"; - } -} - void IrPrinter::handle(const Bool* b) { if (print_inline_ && b->definition() != nullptr) { os_ << "( "; @@ -124,10 +184,9 @@ void IrPrinter::handle(const Bool* b) { return; } - if (b->isSymbolic()) { - os_ << "b" << b->name(); - } else { - os_ << "bool(" << *(b->value()) << ")"; + os_ << "b" << varName(b); + if (b->isConst()) { + os_ << "(" << (b->value().value() ? "true" : "false") << ")"; } } @@ -140,7 +199,7 @@ void IrPrinter::handle(const Double* d) { } if (d->isSymbolic()) { - os_ << "d" << d->name(); + os_ << "d" << varName(d); } else { os_ << "double(" << std::setprecision( @@ -160,30 +219,20 @@ void IrPrinter::handle(const Int* i) { } if (i->isSymbolic()) { - os_ << "i" << i->name(); + os_ << "i" << varName(i); } else { os_ << *(i->value()); } } -void IrPrinter::handle(const NamedScalar* i) { - os_ << i->name(); -} - -static bool isTV(const Val* val) { - return val->getValType().value() == ValType::TensorView; -} - -// Check if we're a TensorView op that we can generate code for. 
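The kernel-IR printing above (handleScope, the indent_size_ counter, and the chained indent() << ... style used for the op handlers that follow) comes down to one counter that is bumped around each nested scope. A self-contained sketch of that scheme with hypothetical ToyStmt/ToyPrinter types:

// Indentation-aware printing: indent() returns the stream so callers can
// chain, and handleScope() brackets a nested body with ++/-- of the counter.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ToyStmt {
  std::string text;                           // e.g. "T1 = T0 + 1"
  std::vector<std::unique_ptr<ToyStmt>> body; // nested scope (loop/if body)
};

class ToyPrinter {
 public:
  explicit ToyPrinter(std::ostream& os) : os_(os) {}

  void handle(const ToyStmt& stmt) {
    indent() << stmt.text << (stmt.body.empty() ? "\n" : ":\n");
    handleScope(stmt.body);
  }

 private:
  // Returning the stream lets callers write `indent() << ...`.
  std::ostream& indent() {
    for (int i = 0; i < indent_size_; ++i) {
      os_ << "  ";
    }
    return os_;
  }

  void handleScope(const std::vector<std::unique_ptr<ToyStmt>>& scope) {
    ++indent_size_;
    for (const auto& stmt : scope) {
      handle(*stmt);
    }
    --indent_size_;
  }

  std::ostream& os_;
  int indent_size_ = 0;
};

int main() {
  ToyStmt loop{"FOR i in iS0{16}", {}};
  loop.body.push_back(
      std::make_unique<ToyStmt>(ToyStmt{"T1[i] = T0[i] * 2", {}}));
  ToyPrinter(std::cout).handle(loop); // prints the loop, then its body indented
  return 0;
}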
-static bool isTVOp(const Expr* expr) { - return expr->outputs().size() == 1 && isTV(expr->outputs().front()); +void IrPrinter::handle(const NamedScalar* ns) { + os_ << ns->name(); } void IrPrinter::handle(const UnaryOp* uop) { - bool istvop = isTVOp(uop); + bool istvop = ir_utils::isTvOp(uop); if (!print_inline_) { - indent(); - os_ << uop->out(); + indent() << uop->out(); if (istvop) { os_ << "\n"; indent_size_++; @@ -230,10 +279,9 @@ void IrPrinter::handle(const UnaryOp* uop) { } void IrPrinter::handle(const BinaryOp* bop) { - bool istvop = isTVOp(bop); + bool istvop = ir_utils::isTvOp(bop); if (!print_inline_) { - indent(); - os_ << bop->out(); + indent() << bop->out(); // tensor operations tend to be long, break them up into multiple lines if (istvop) { @@ -286,7 +334,7 @@ void IrPrinter::handle(const BinaryOp* bop) { } void IrPrinter::handle(const TernaryOp* top) { - bool istvop = isTVOp(top); + bool istvop = ir_utils::isTvOp(top); if (!print_inline_) { indent(); os_ << top->out(); @@ -327,18 +375,16 @@ void IrPrinter::handle(const TernaryOp* top) { } void IrPrinter::handle(const ReductionOp* rop) { - indent(); - os_ << rop->out() << " = reduction( " << rop->in() - << ", op = " << rop->getReductionOpType() - << ", initial value = " << rop->init() << " )\n"; + indent() << rop->out() << " = reduction( " << rop->in() + << ", op = " << rop->getReductionOpType() + << ", initial value = " << rop->init() << " )\n"; } void IrPrinter::handle(const WelfordOp* wop) { - indent(); - os_ << wop->outAvg() << "(Avg),\n" - << wop->outVar() << "(Var),\n" - << wop->outN() << "(Count)" - << "\n = Welford ( "; + indent() << wop->outAvg() << "(Avg),\n" + << wop->outVar() << "(Var),\n" + << wop->outN() << "(Count)" + << "\n = Welford ( "; if (wop->singleValue()) { os_ << wop->inAvg() << "(Avg), "; } else { @@ -353,24 +399,48 @@ void IrPrinter::handle(const WelfordOp* wop) { } void IrPrinter::handle(const BroadcastOp* bop) { - indent(); - os_ << bop->out() << " = broadcast( " << bop->in() << " )\n"; + indent() << bop->out() << " = broadcast( " << bop->in() << " )\n"; +} + +void IrPrinter::handle(const Split* s) { + os_ << (s->innerSplit() ? "Split: " : "Outer split: "); + handle(s->in()); + os_ << " by factor " << s->factor() << " -> "; + handle(s->outer()); + os_ << ", "; + handle(s->inner()); + if (s->startOffset()) { + os_ << ", start offset: "; + handle(s->startOffset()); + } + if (s->stopOffset()) { + os_ << ", stop offset: "; + handle(s->stopOffset()); + } + os_ << "\n"; +} + +void IrPrinter::handle(const Merge* m) { + os_ << "Merge: "; + handle(m->outer()); + os_ << " and "; + handle(m->inner()); + os_ << " -> "; + handle(m->out()); + os_ << "\n"; } void IrPrinter::handle(const TransposeOp* top) { - indent(); - os_ << top->out() << " = transpose( " << top->in() << " )\n"; + indent() << top->out() << " = transpose( " << top->in() << " )\n"; } void IrPrinter::handle(const ShiftOp* sop) { - indent(); - os_ << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets() - << "}, padding = " << (sop->pad() ? 
"true" : "false") << " )\n"; + indent() << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets() + << "}, {" << sop->padWidth() << "} )\n"; } void IrPrinter::handle(const GatherOp* op) { - indent(); - os_ << op->out() << " = gather( " << op->in() << ", {"; + indent() << op->out() << " = gather( " << op->in() << ", {"; bool no_comma = true; for (const auto& s : op->windowShape()) { if (!no_comma) { @@ -392,36 +462,187 @@ void IrPrinter::handle(const GatherOp* op) { } void IrPrinter::handle(const ViewOp* top) { - indent(); - os_ << top->out() << " = view( " << top->in() << " )\n"; + indent() << top->out() << " = view( " << top->in() << " )\n"; } -void IrPrinter::handle(const Split* s) { - os_ << (s->innerSplit() ? "Split: " : "Outer split: "); - handle(s->in()); - os_ << " by factor " << s->factor() << " -> "; - handle(s->outer()); - os_ << ", "; - handle(s->inner()); - if (s->startOffset()) { - os_ << ", start offset: "; - handle(s->startOffset()); +void IrPrinter::handle(const kir::Predicate* node) { + switch (node->predicate_type()) { + case PredicateType::Inline: { + os_ << "Inline_Predicate"; + break; + } + case PredicateType::Manual: { + os_ << node->value(); + break; + } + case PredicateType::Misaligned: { + os_ << "Misaligned_Predicate"; + break; + } + case PredicateType::Padding: { + os_ << "Padding_Predicate"; + break; + } + case PredicateType::Shift: { + os_ << "Shift_Predicate"; + break; + } + case PredicateType::Unswitch: { + os_ << "Unswitch_Predicate"; + break; + } + case PredicateType::Vectorize: { + os_ << "Vectorize_Predicate"; + break; + } + default: + break; } - if (s->stopOffset()) { - os_ << ", stop offset: "; - handle(s->stopOffset()); +} + +void IrPrinter::handle(const kir::TensorIndex* ti) { + os_ << "T" << varName(ti); + switch (ti->view()->getMemoryType()) { + case MemoryType::Global: + os_ << "_g"; + break; + case MemoryType::Shared: + os_ << "_s"; + break; + case MemoryType::Local: + os_ << "_l"; + break; } + os_ << "["; + for (auto index : ti->indices()) { + print_inline(index); + if (index != ti->indices().back()) { + os_ << ", "; + } + } + os_ << "]"; + os_ << " view( T" << varName(ti->view()) << " )"; +} + +void IrPrinter::handle(const kir::Allocate* node) { + indent(); + handle(node->buffer()); + os_ << " = ALLOCATE(" + << "mem_type=" << node->memoryType() << ", " + << "size="; + print_inline(node->size()); + os_ << ", " + << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n"; + if (node->alias() != nullptr) { + indent() << kTab << ".alias="; + handle(node->alias()->buffer()); + os_ << "\n"; + } +} + +void IrPrinter::handle(const kir::Sync* node) { + indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync()) + << ")\n"; +} + +void IrPrinter::handle(const kir::ForLoop* node) { + indent() << "FOR "; + handle(node->index()); + os_ << " in "; + handle(node->iter_domain()); + os_ << ":\n"; + handleScope(node->body()); +} + +void IrPrinter::handle(const kir::IfThenElse* node) { + indent() << "IF "; + handle(node->predicate()); + os_ << ":\n"; + handleScope(node->thenBody()); + if (node->hasElse()) { + indent() << "ELSE:\n"; + handleScope(node->elseBody()); + } +} + +void IrPrinter::handle(const kir::GridBroadcast* node) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} + +void IrPrinter::handle(const kir::GridReduction* node) { + const auto* reduction_op = node->reduction_op(); + indent(); + handle(reduction_op->out()); + os_ << " = " + << "GRID_REDUCTION(op='" << reduction_op->getReductionOpType() << "'" + << ", in="; + 
handle(reduction_op->in()); + os_ << ", init="; + handle(reduction_op->init()); + os_ << ", pred="; + handle(reduction_op->predicate()); + os_ << ")\n"; + indent() << kTab << ".reduction_buffer="; + handle(node->reduction_buffer()->buffer()); + os_ << "\n"; + indent() << kTab << ".sync_buffer="; + handle(node->sync_buffer()->buffer()); + os_ << "\n"; + indent() << kTab << ".grid_pred="; + handle(node->predicate()); os_ << "\n"; } -void IrPrinter::handle(const Merge* m) { - os_ << "Merge: "; - handle(m->outer()); - os_ << " and "; - handle(m->inner()); - os_ << " -> "; - handle(m->out()); +void IrPrinter::handle(const kir::GridWelford* node) { + const auto* welford_op = node->welford_op(); + indent(); + handle(welford_op->outVar()); + os_ << ","; + handle(welford_op->outAvg()); + os_ << ","; + handle(welford_op->outN()); + os_ << " = " + << "GRID_WELFORD(" + << "inAvg="; + handle(welford_op->inAvg()); + if (!welford_op->inN()->isOneInt()) { + indent() << ", inVar="; + handle(welford_op->inVar()); + } + indent() << ", inN="; + handle(welford_op->inN()); + if (!welford_op->initN()->isZeroInt()) { + indent() << ", initVar="; + handle(welford_op->initVar()); + os_ << " initAvg="; + handle(welford_op->initAvg()); + os_ << " initN="; + handle(welford_op->initN()); + } + indent() << ", pred="; + handle(welford_op->predicate()); + os_ << ")\n"; + indent() << kTab << ".var_buffer="; + handle(node->var_buffer()->buffer()); + os_ << ".avg_buffer="; + handle(node->avg_buffer()->buffer()); + os_ << ".n_buffer="; + handle(node->N_buffer()->buffer()); + os_ << "\n"; + indent() << kTab << ".sync_buffer="; + handle(node->sync_buffer()->buffer()); os_ << "\n"; + indent() << kTab << ".grid_pred="; + handle(node->predicate()); + os_ << "\n"; +} + +void IrPrinter::handle(const kir::InitMagicZero* node) { + indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; +} + +void IrPrinter::handle(const kir::UpdateMagicZero* node) { + indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; } void IrTransformPrinter::handle(Fusion* f) { @@ -450,7 +671,7 @@ void IrTransformPrinter::printTransforms(TensorView* tv) { os() << ")\n"; for (auto exp : all_exp) { - os() << " "; + os() << " "; IrPrinter::handle(exp); } } diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index c080c3f8f99..f8c07886114 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -13,21 +13,30 @@ namespace jit { namespace fuser { namespace cuda { +class Fusion; +namespace kir { +class Kernel; +class Scope; +} // namespace kir + //! Define pretty printing functions for IR nodes //! //! This class is intended for debug printing, so it attempts //! to handle invalid states as well. //! 
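// Illustrative sketch of using the unified printer declared below (the helper
// name and call site are assumptions; assumes the torch::jit::fuser::cuda
// namespace and the headers from the files above are in scope): the same
// IrPrinter instance now handles both fusion math and a lowered kir::Kernel.
#include <iostream>

void debugDump(Fusion* fusion, const kir::Kernel* kernel) {
  IrPrinter printer(std::cout);
  printer.handle(fusion); // TensorView math: UnaryOp/BinaryOp/ReductionOp, ...
  printer.handle(kernel); // "KERNEL (...) -> (...) :" with ForLoop/IfThenElse/Allocate body
}
// kir::Kernel::print() wraps this same code path for quick debugging.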
class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { + static constexpr char const* kTab = " "; + public: explicit IrPrinter(std::ostream& os) : os_(os) {} // Indent the generated code - void indent() { + std::ostream& indent() { for (const auto i : c10::irange(indent_size_)) { (void)i; // Suppress unused variable warning os_ << " "; } + return os_; } void resetIndent() { @@ -38,6 +47,8 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { return print_inline_; } + using OptInConstDispatch::handle; + virtual void handle(Fusion* f); // handle calls some non const fusion ops, @@ -52,30 +63,50 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { handle(&f); } - void handle(const Statement* s) override; - void handle(const Val* v) override; - void handle(const Expr* e) override; - - void handle(const TensorDomain*) override; - void handle(const TensorView*) override; - void handle(const IterDomain*) override; - - void handle(const Bool*) override; - void handle(const Double*) override; - void handle(const Int*) override; - void handle(const NamedScalar*) override; - - void handle(const UnaryOp*) override; - void handle(const BinaryOp*) override; - void handle(const TernaryOp*) override; - void handle(const ReductionOp*) override; - void handle(const WelfordOp*) override; - void handle(const BroadcastOp*) override; - void handle(const TransposeOp*) override; - void handle(const ShiftOp*) override; - void handle(const GatherOp*) override; - void handle(const ViewOp*) override; - + virtual void handle(const kir::Kernel* kernel); + virtual void handle(kir::Kernel& kernel); + + void handleScope(const kir::Scope& scope); + + void handle(const Statement* s) final; + void handle(const Val* v) final; + void handle(const Expr* e) final; + + void handle(const IterDomain*) final; + void handle(const TensorDomain*) final; + void handle(const TensorView*) final; + + void handle(const Bool*) final; + void handle(const Double*) final; + void handle(const Int*) final; + void handle(const NamedScalar*) final; + + void handle(const UnaryOp*) final; + void handle(const BinaryOp*) final; + void handle(const TernaryOp*) final; + void handle(const ReductionOp*) final; + void handle(const WelfordOp*) final; + void handle(const BroadcastOp*) final; + void handle(const TransposeOp*) final; + void handle(const ShiftOp*) final; + void handle(const GatherOp*) final; + void handle(const ViewOp*) final; + + void handle(const kir::Predicate*) final; + void handle(const kir::TensorIndex*) final; + + void handle(const kir::GridBroadcast*) final; + void handle(const kir::GridReduction*) final; + void handle(const kir::GridWelford*) final; + void handle(const kir::ForLoop*) final; + void handle(const kir::IfThenElse*) final; + void handle(const kir::Allocate*) final; + void handle(const kir::Sync*) final; + void handle(const kir::InitMagicZero*) final; + void handle(const kir::UpdateMagicZero*) final; + + // IR math printer overrides these to prevent them from printing, keep + // override void handle(const Split*) override; void handle(const Merge*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 1465a88bef3..884b6a6e0ec 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -38,19 +40,19 @@ class ScalarCheck : OptInConstDispatch { } private: - void handle(const Bool* b) override { 
+ void handle(const Bool* b) final { same_ = v1_->as()->sameAs(v2_->as()); } - void handle(const Double* d) override { + void handle(const Double* d) final { same_ = v1_->as()->sameAs(v2_->as()); } - void handle(const Int* i) override { + void handle(const Int* i) final { same_ = v1_->as()->sameAs(v2_->as()); } - void handle(const NamedScalar* ns) override { + void handle(const NamedScalar* ns) final { same_ = v1_->as()->sameAs(v2_->as()); } @@ -70,6 +72,16 @@ bool areEqualScalars(Val* v1, Val* v2) { return ScalarCheck::sameAs(v1, v2); } +Bool::Bool(IrBuilderPasskey passkey) + : Val(passkey, ValType::Scalar, DataType::Bool), + maybe_value_{c10::nullopt} {} + +Bool::Bool(IrBuilderPasskey passkey, bool value) + : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {} + +Bool::Bool(IrBuilderPasskey passkey, c10::optional value) + : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {} + Bool::Bool(const Bool* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -87,6 +99,16 @@ bool Bool::sameAs(const Statement* other) const { return false; } +Double::Double(IrBuilderPasskey passkey) + : Val(passkey, ValType::Scalar, DataType::Double), + maybe_value_{c10::nullopt} {} + +Double::Double(IrBuilderPasskey passkey, ScalarType value) + : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {} + +Double::Double(IrBuilderPasskey passkey, c10::optional value) + : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {} + Double::Double(const Double* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -103,6 +125,16 @@ bool Double::sameAs(const Statement* other) const { return false; } +Int::Int(IrBuilderPasskey passkey) + : Val(passkey, ValType::Scalar, DataType::Int), + maybe_value_{c10::nullopt} {} + +Int::Int(IrBuilderPasskey passkey, ScalarType value) + : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {} + +Int::Int(IrBuilderPasskey passkey, c10::optional value) + : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {} + Int::Int(const Int* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -120,11 +152,13 @@ bool Int::sameAs(const Statement* other) const { return false; } -UnaryOp::UnaryOp(UnaryOpType type, Val* out, Val* in) - : Expr(ExprType::UnaryOp), unary_op_type_{type}, out_{out}, in_{in} { +UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in) + : Expr(passkey, ExprType::UnaryOp), + unary_op_type_{type}, + out_{out}, + in_{in} { addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) @@ -146,8 +180,13 @@ bool UnaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -BinaryOp::BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs) - : Expr(ExprType::BinaryOp), +BinaryOp::BinaryOp( + IrBuilderPasskey passkey, + BinaryOpType type, + Val* out, + Val* lhs, + Val* rhs) + : Expr(passkey, ExprType::BinaryOp), binary_op_type_{type}, out_{out}, lhs_{lhs}, @@ -155,7 +194,6 @@ BinaryOp::BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs) addOutput(out); addInput(lhs); addInput(rhs); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) @@ -178,8 +216,14 @@ bool BinaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -TernaryOp::TernaryOp(TernaryOpType type, Val* out, 
Val* in1, Val* in2, Val* in3) - : Expr(ExprType::TernaryOp), +TernaryOp::TernaryOp( + IrBuilderPasskey passkey, + TernaryOpType type, + Val* out, + Val* in1, + Val* in2, + Val* in3) + : Expr(passkey, ExprType::TernaryOp), ternary_op_type_{type}, out_{out}, in1_{in1}, @@ -189,7 +233,6 @@ TernaryOp::TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3) addInput(in1); addInput(in2); addInput(in3); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) @@ -213,8 +256,12 @@ bool TernaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -BroadcastOp::BroadcastOp(Val* out, Val* in, std::vector is_broadcast_dims) - : Expr(ExprType::BroadcastOp), +BroadcastOp::BroadcastOp( + IrBuilderPasskey passkey, + Val* out, + Val* in, + std::vector is_broadcast_dims) + : Expr(passkey, ExprType::BroadcastOp), out_(out), in_(in), is_broadcast_dims_(std::move(is_broadcast_dims)) { @@ -226,12 +273,18 @@ BroadcastOp::BroadcastOp(Val* out, Val* in, std::vector is_broadcast_dims) auto in_type = in->getValType().value(); TORCH_INTERNAL_ASSERT( - out_type == ValType::TensorView && in_type == ValType::TensorView, + (out_type == ValType::TensorView && in_type == ValType::TensorView) || + (out_type == ValType::TensorIndex && in_type == ValType::TensorIndex), "Cannot braodcast a non-tensor object."); addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); + + if (!out->isA() || !in->isA()) { + return; + } + + passkey.ir_container_->registerExpr(exprPasskey(), this); // This is a generic check that root dims of a consumer and producer match. // Maybe we shouldn't relegate it to this constructor. @@ -294,37 +347,44 @@ bool BroadcastOp::sameAs(const Statement* other) const { } ReductionOp::ReductionOp( + IrBuilderPasskey passkey, BinaryOpType reduction_op_type, Val* init, Val* out, Val* in) - : Expr(ExprType::ReductionOp), + : Expr(passkey, ExprType::ReductionOp), reduction_op_type_(reduction_op_type), init_(init), out_(out), in_(in) { - TORCH_CHECK(out->getValType().value() == ValType::TensorView); + TORCH_CHECK( + out->getValType().value() == ValType::TensorView || + out->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - in->getValType() == ValType::TensorView && - out->getValType() == ValType::TensorView, + (in->getValType() == ValType::TensorView && + out->getValType() == ValType::TensorView) || + (in->getValType() == ValType::TensorIndex && + out->getValType() == ValType::TensorIndex), "Reduction operation was created that does not have tensor inputs and outputs."); - TORCH_INTERNAL_ASSERT( - TensorDomain::noReductions(in->as()->getMaybeRFactorDomain()) - .size() == out->as()->getRootDomain().size(), - "Reduction operation created with mismatched domains."); - + if (in->isA()) { + TORCH_INTERNAL_ASSERT( + TensorDomain::noReductions( + in->as()->getMaybeRFactorDomain()) + .size() == out->as()->getRootDomain().size(), + "Reduction operation created with mismatched domains."); + } TORCH_INTERNAL_ASSERT( init->isConstScalar(), "Tried to create a reduction operation whith an initial value that isn't a constant."); addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } WelfordOp::WelfordOp( + IrBuilderPasskey passkey, Val* out_avg, Val* out_var, Val* out_N, @@ -334,7 +394,7 @@ WelfordOp::WelfordOp( Val* in_avg, Val* in_var, Val* in_N) - : Expr(ExprType::WelfordOp), + : Expr(passkey, ExprType::WelfordOp), out_avg_(out_avg), 
out_var_(out_var), out_N_(out_N), @@ -345,9 +405,15 @@ WelfordOp::WelfordOp( in_var_(in_var), in_N_(in_N) { // Check output type - TORCH_INTERNAL_ASSERT(out_avg->getValType().value() == ValType::TensorView); - TORCH_INTERNAL_ASSERT(out_var->getValType().value() == ValType::TensorView); - TORCH_INTERNAL_ASSERT(out_N->getValType().value() == ValType::TensorView); + TORCH_INTERNAL_ASSERT( + out_avg->getValType().value() == ValType::TensorView || + out_avg->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT( + out_var->getValType().value() == ValType::TensorView || + out_var->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT( + out_N->getValType().value() == ValType::TensorView || + out_N->getValType().value() == ValType::TensorIndex); // check initial value TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar); @@ -356,22 +422,32 @@ WelfordOp::WelfordOp( // initial value with a count of 1 is un-common enough that I'll push // the responsibility of creating all-zero var tensors to the user TORCH_INTERNAL_ASSERT( - init_avg && init_avg->getValType().value() == ValType::TensorView); + init_avg && + (init_avg->getValType().value() == ValType::TensorView || + init_avg->getValType().value() == ValType::TensorIndex)); TORCH_INTERNAL_ASSERT( - init_var && init_var->getValType().value() == ValType::TensorView); + init_var && + (init_var->getValType().value() == ValType::TensorView || + init_var->getValType().value() == ValType::TensorIndex)); } TORCH_INTERNAL_ASSERT( - in_avg && in_avg->getValType().value() == ValType::TensorView); + in_avg && + (in_avg->getValType().value() == ValType::TensorView || + in_avg->getValType().value() == ValType::TensorIndex), + in_avg->getValType().value()); // check input TORCH_INTERNAL_ASSERT( in_N->getValType().value() == ValType::Scalar || - in_N->getValType().value() == ValType::TensorView); + in_N->getValType().value() == ValType::TensorView || + in_N->getValType().value() == ValType::TensorIndex); if (!in_N->isOneInt()) { // when input is only one value, only the value is required through avg // input the var part is implicitly 0 and codegen will handle that. 
TORCH_INTERNAL_ASSERT( - in_var && in_var->getValType().value() == ValType::TensorView); + in_var && + (in_var->getValType().value() == ValType::TensorView || + in_var->getValType().value() == ValType::TensorIndex)); } addOutput(out_avg); @@ -384,8 +460,6 @@ WelfordOp::WelfordOp( addInput(in_var); } addInput(in_N); - - name_ = FusionGuard::getCurFusion()->registerExpr(this); } WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) @@ -444,10 +518,11 @@ bool ReductionOp::sameAs(const Statement* other) const { } TransposeOp::TransposeOp( + IrBuilderPasskey passkey, TensorView* out, TensorView* in, std::vector new2old) - : Expr(ExprType::TransposeOp), + : Expr(passkey, ExprType::TransposeOp), out_(out), in_(in), new2old_(std::move(new2old)) { @@ -481,7 +556,6 @@ TransposeOp::TransposeOp( addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) @@ -490,12 +564,17 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), new2old_(src->new2old_) {} -ShiftOp::ShiftOp(Val* out, Val* in, std::vector offsets, bool pad) - : Expr(ExprType::ShiftOp), +ShiftOp::ShiftOp( + IrBuilderPasskey passkey, + Val* out, + Val* in, + std::vector offsets, + std::vector pad_width) + : Expr(passkey, ExprType::ShiftOp), out_(out), in_(in), offsets_(std::move(offsets)), - pad_(pad) { + pad_width_(std::move(pad_width)) { // clang-tidy complains about out_ that it may be null. TORCH_INTERNAL_ASSERT(out_ != nullptr); TORCH_INTERNAL_ASSERT(in_ != nullptr); @@ -514,9 +593,15 @@ ShiftOp::ShiftOp(Val* out, Val* in, std::vector offsets, bool pad) "Invalid offset vector: ", offsets_); + TORCH_INTERNAL_ASSERT( + pad_width_.size() == + TensorDomain::noReductions(in_->as()->getRootDomain()) + .size(), + "Invalid padding width vector: ", + pad_width_); + addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) @@ -524,7 +609,7 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)), offsets_(src->offsets_), - pad_(src->pad_) {} + pad_width_(src->pad_width_) {} bool ShiftOp::sameAs(const Statement* other) const { if (this == other) { @@ -541,11 +626,12 @@ bool ShiftOp::sameAs(const Statement* other) const { } GatherOp::GatherOp( + IrBuilderPasskey passkey, Val* out, Val* in, - std::vector window_shape, - std::vector> pad_width) - : Expr(ExprType::GatherOp), + std::vector window_shape, + std::vector> pad_width) + : Expr(passkey, ExprType::GatherOp), out_(out), in_(in), window_shape_(std::move(window_shape)), @@ -578,28 +664,14 @@ GatherOp::GatherOp( addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) { - std::transform( - src->window_shape_.begin(), - src->window_shape_.end(), - std::back_inserter(window_shape_), - [&ir_cloner](const auto& x) { return ir_cloner->clone(x); }); - for (const auto& pad : src->pad_width_) { - std::vector pad_clone; - std::transform( - pad.begin(), - pad.end(), - std::back_inserter(pad_clone), - [&ir_cloner](const auto& x) { return ir_cloner->clone(x); }); - pad_width_.push_back(pad_clone); - } -} + in_(ir_cloner->clone(src->in_)), + window_shape_(src->window_shape_), + pad_width_(src->pad_width_) 
{} bool GatherOp::sameAs(const Statement* other) const { if (this == other) { @@ -609,23 +681,10 @@ bool GatherOp::sameAs(const Statement* other) const { return false; } const auto other_op = other->as(); - if (windowShape().size() != other_op->windowShape().size()) { - return false; - } - for (const auto i : c10::irange(windowShape().size())) { - if (!windowShape()[i]->sameAs(other_op->windowShape()[i])) { - return false; - } - } - if (padWidth().size() != other_op->padWidth().size()) { + if (windowShape() != other_op->windowShape() || + padWidth() != other_op->padWidth()) { return false; } - for (const auto i : c10::irange(padWidth().size())) { - if (!padWidth()[i][0]->sameAs(other_op->padWidth()[i][0]) || - !padWidth()[i][1]->sameAs(other_op->padWidth()[i][1])) { - return false; - } - } return Expr::sameAs(other); } @@ -638,11 +697,10 @@ int GatherOp::gatherAxis(int axis) const { return int(windowShape().size()) + axis; } -ViewOp::ViewOp(TensorView* out, TensorView* in) - : Expr(ExprType::ViewOp), out_(out), in_(in) { +ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) + : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) { addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) @@ -651,12 +709,14 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)) {} IterDomain::IterDomain( + IrBuilderPasskey passkey, Val* start, Val* extent, ParallelType parallel_type, IterType iter_type, bool is_rfactor_domain) : IterDomain( + passkey, start, extent, nullptr, @@ -665,16 +725,19 @@ IterDomain::IterDomain( is_rfactor_domain) {} IterDomain::IterDomain( + IrBuilderPasskey passkey, Val* start, Val* extent, Val* stop_offset, ParallelType parallel_type, IterType iter_type, bool is_rfactor_domain) - : Val(ValType::IterDomain, DataType::Int, false), + : Val(passkey, ValType::IterDomain, DataType::Int), start_(start), extent_(extent), - stop_offset_(stop_offset == nullptr ? new Int(0) : stop_offset), + stop_offset_( + stop_offset == nullptr ? 
passkey.ir_container_->zeroVal() + : stop_offset), parallel_type_(parallel_type), iter_type_(iter_type), is_rfactor_domain_(is_rfactor_domain) { @@ -693,8 +756,6 @@ IterDomain::IterDomain( "Cannot create an iter domain with a start that is not an int but received ", start, " ."); - - name_ = fusion_->registerVal(this); } IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) @@ -729,6 +790,22 @@ bool IterDomain::sameAs(const Statement* other) const { return is_same; } +// Returns a new IterDomain matching properties of this +IterDomain* IterDomain::clone() const { + auto cloned = IrBuilder::create( + ir_container_, + start(), + extent(), + stopOffset(), + getParallelType(), + getIterType(), + isRFactorProduct()); + + cloned->is_padded_dimension_ = is_padded_dimension_; + cloned->padded_to_size_ = padded_to_size_; + return cloned; +} + std::vector IterDomain::clone( const std::vector& domains) { std::vector cloned_domains; @@ -781,14 +858,15 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { itype = IterType::Iteration; } - IterDomain* merged_id = new IterDomain( - new Int(0), + IterDomain* merged_id = IrBuilder::create( + outer->container(), + outer->container()->zeroVal(), merged_id_size->as(), outer->getParallelType(), itype, outer->isRFactorProduct() || inner->isRFactorProduct()); - new Merge(merged_id, outer, inner); + IrBuilder::create(outer->container(), merged_id, outer, inner); return merged_id; } @@ -811,7 +889,8 @@ std::pair IterDomain::split( if (factor->getValType() == ValType::Scalar) { TORCH_CHECK( factor->isConstScalar() || - FusionGuard::getCurFusion()->hasInput(factor), + (FusionGuard::getCurFusion() == factor->fusion() && + factor->isFusionInput()), factor, " is not a constant nor an input. It must be one or the other to be used in a split.", " If you want a symbolic split based on a thread dimension please use IterDomain::split(IterDomain*, ParallelType);"); @@ -832,24 +911,33 @@ std::pair IterDomain::split( in->definition() == nullptr, "Partial split is only allowed with root domains"); } - // outer loop IterDomain - IterDomain* ido = new IterDomain( - new Int(0), + IterDomain* ido = IrBuilder::create( + in->container(), + in->container()->zeroVal(), inner_split ? remainder->as() : factor, in->getParallelType(), in->getIterType(), in->isRFactorProduct()); // inner loop IterDomain - IterDomain* idi = new IterDomain( - new Int(0), + IterDomain* idi = IrBuilder::create( + in->container(), + in->container()->zeroVal(), inner_split ? 
factor : remainder->as(), in->getParallelType(), in->getIterType(), in->isRFactorProduct()); - new Split(ido, idi, in, factor, inner_split, start_offset, stop_offset); + IrBuilder::create( + in->container(), + ido, + idi, + in, + factor, + inner_split, + start_offset, + stop_offset); return {ido, idi}; } @@ -864,7 +952,9 @@ std::pair IterDomain::split( } std::pair IterDomain::stridedSplit(int factor) { - auto split_out = IterDomain::split(this, new Int(factor), true); + // Use partial split so that only valid values are retained + auto split_out = IterDomain::split( + this, IrBuilder::create(container(), factor), true, true); split_out.second->iter_type_ = IterType::Stride; split_out.first->is_rfactor_domain_ = true; @@ -907,9 +997,10 @@ Val* IterDomain::stop() const { } TensorDomain::TensorDomain( + IrBuilderPasskey passkey, std::vector root_domain, std::vector contiguity) - : Val(ValType::TensorDomain, DataType::Null, false), + : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), contiguity_( contiguity.empty() ? std::vector(root_domain_.size(), false) @@ -925,14 +1016,14 @@ TensorDomain::TensorDomain( has_nontrivial_reduction_ = false; domain_ = root_domain_; resetDomains(); - name_ = fusion_->registerVal(this); } TensorDomain::TensorDomain( + IrBuilderPasskey passkey, std::vector root_domain, std::vector domain, std::vector contiguity) - : Val(ValType::TensorDomain, DataType::Null, false), + : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), domain_(std::move(domain)), contiguity_( @@ -963,15 +1054,15 @@ TensorDomain::TensorDomain( // Just due to clang-tidy, correct value set in resetDomains has_nontrivial_reduction_ = false; resetDomains(); - name_ = fusion_->registerVal(this); } TensorDomain::TensorDomain( + IrBuilderPasskey passkey, std::vector root_domain, std::vector rfactor_domain, std::vector domain, std::vector contiguity) - : Val(ValType::TensorDomain, DataType::Null, false), + : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), domain_(std::move(domain)), rfactor_domain_(std::move(rfactor_domain)), @@ -1013,7 +1104,6 @@ TensorDomain::TensorDomain( // Just due to clang-tidy, correct value set in resetDomains has_nontrivial_reduction_ = false; resetDomains(); - name_ = fusion_->registerVal(this); } TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) @@ -1026,6 +1116,30 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) contiguity_(src->contiguity()), has_nontrivial_reduction_(src->has_nontrivial_reduction_) {} +namespace { +std::vector lowerIterDomains( + const std::vector& domains) { + std::vector lowered_domains; + lowered_domains.reserve(domains.size()); + for (const auto iter_domain : domains) { + lowered_domains.push_back(iter_domain); + } + return lowered_domains; +}; +} // namespace + +bool TensorDomain::hasBlockBroadcast() const { + return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { + return id->isBroadcast() && id->isThreadDim(); + }); +} + +bool TensorDomain::hasGridBroadcast() const { + return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { + return id->isBroadcast() && id->isBlockDim(); + }); +} + bool TensorDomain::operator==(const TensorDomain& other) const { // Checks equality of each class field. 
Should not be necessary to // check no_bcast_domain_ and no_reduction_domain_ as they are just @@ -1389,6 +1503,7 @@ std::pair TensorDomain::rFactor( } Split::Split( + IrBuilderPasskey passkey, IterDomain* outer, IterDomain* inner, IterDomain* in, @@ -1396,14 +1511,18 @@ Split::Split( bool inner_split, Val* start_offset, Val* stop_offset) - : Expr(ExprType::Split), + : Expr(passkey, ExprType::Split), outer_{outer}, inner_{inner}, in_{in}, factor_{factor}, inner_split_{inner_split}, - start_offset_{start_offset != nullptr ? start_offset : new Int(0)}, - stop_offset_{stop_offset != nullptr ? stop_offset : new Int(0)} { + start_offset_{ + start_offset != nullptr ? start_offset + : passkey.ir_container_->zeroVal()}, + stop_offset_{ + stop_offset != nullptr ? stop_offset + : passkey.ir_container_->zeroVal()} { TORCH_INTERNAL_ASSERT( factor_->isAnInt(), "Attempted to create a Split node with a non-integer factor."); @@ -1412,7 +1531,6 @@ Split::Split( addInput(in); // TODO add factor as an input, need to check Split::Split during validation // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } Split::Split(const Split* src, IrCloner* ir_cloner) @@ -1453,12 +1571,15 @@ bool Split::sameAs(const Statement* other) const { stopOffset()->sameAs(other->as()->stopOffset()); } -Merge::Merge(IterDomain* out, IterDomain* outer, IterDomain* inner) - : Expr(ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} { +Merge::Merge( + IrBuilderPasskey passkey, + IterDomain* out, + IterDomain* outer, + IterDomain* inner) + : Expr(passkey, ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} { addOutput(out); addInput(outer); addInput(inner); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } Merge::Merge(const Merge* src, IrCloner* ir_cloner) @@ -1477,6 +1598,12 @@ bool Merge::sameAs(const Statement* other) const { return Expr::sameAs(other); } +NamedScalar::NamedScalar( + IrBuilderPasskey passkey, + std::string name, + DataType dtype) + : Val(passkey, ValType::NamedScalar, dtype), name_(std::move(name)) {} + NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner) : Val(src, ir_cloner), name_(src->name_) {} @@ -1495,13 +1622,15 @@ NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { isParallelTypeThread(p_type), "Cannot get parallel dim of non thread type, received: ", p_type); + TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); std::string parallel_dim = stringifyThreadSize(p_type); - return new NamedScalar(parallel_dim, DataType::Int); + return IrBuilder::create(parallel_dim, DataType::Int); } NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) { + TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); std::string parallel_ind = stringifyThread(p_type); - return new NamedScalar(parallel_ind, DataType::Int); + return IrBuilder::create(parallel_ind, DataType::Int); } c10::optional NamedScalar::getParallelDim() const { diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h index a2c14386147..91d07b76b80 100644 --- a/torch/csrc/jit/codegen/cuda/ir_printer.h +++ b/torch/csrc/jit/codegen/cuda/ir_printer.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 5bf05b0f516..004cfa23dff 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -1,5 +1,6 @@ 
#include #include +#include #include #include @@ -140,7 +141,8 @@ struct SubstituteInExpr : public OptInDispatch { reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in(); auto out = reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out(); - expr_ = new UnaryOp(unary_expr->getUnaryOpType(), out, in); + expr_ = IrBuilder::create( + unary_expr->container(), unary_expr->getUnaryOpType(), out, in); } void handle(BinaryOp* binary_expr) final { @@ -151,7 +153,12 @@ struct SubstituteInExpr : public OptInDispatch { auto out = reference_->sameAs(binary_expr->out()) ? substitute_ : binary_expr->out(); - expr_ = new BinaryOp(binary_expr->getBinaryOpType(), out, lhs, rhs); + expr_ = IrBuilder::create( + binary_expr->container(), + binary_expr->getBinaryOpType(), + out, + lhs, + rhs); } void handle(TernaryOp* ternary_expr) final { @@ -163,7 +170,13 @@ struct SubstituteInExpr : public OptInDispatch { : ternary_expr->in3(); auto out = reference_->sameAs(ternary_expr->out()) ? substitute_ : ternary_expr->out(); - expr_ = new TernaryOp(ternary_expr->getTernaryOpType(), out, in1, in2, in3); + expr_ = IrBuilder::create( + ternary_expr->container(), + ternary_expr->getTernaryOpType(), + out, + in1, + in2, + in3); } void handle(ReductionOp* reduction_expr) final { @@ -176,8 +189,12 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(reduction_expr->in()) ? substitute_ : reduction_expr->in(); - expr_ = - new ReductionOp(reduction_expr->getReductionOpType(), init, out, in); + expr_ = IrBuilder::create( + reduction_expr->container(), + reduction_expr->getReductionOpType(), + init, + out, + in); } void handle(BroadcastOp* broadcast_expr) final { @@ -187,7 +204,11 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_ : broadcast_expr->in(); - expr_ = new BroadcastOp(out, in, broadcast_expr->getBroadcastDimFlags()); + expr_ = IrBuilder::create( + broadcast_expr->container(), + out, + in, + broadcast_expr->getBroadcastDimFlags()); } void handle(TransposeOp* transpose_expr) final { @@ -201,7 +222,8 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(transpose_expr->in()) ? substitute_->as() : transpose_expr->in(); - expr_ = new TransposeOp(out, in, transpose_expr->new2old()); + expr_ = IrBuilder::create( + transpose_expr->container(), out, in, transpose_expr->new2old()); } void handle(ShiftOp* shift_expr) final { @@ -210,7 +232,12 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in(); - expr_ = new ShiftOp(out, in, shift_expr->offsets(), shift_expr->pad()); + expr_ = IrBuilder::create( + shift_expr->container(), + out, + in, + shift_expr->offsets(), + shift_expr->padWidth()); } void handle(GatherOp* gather_expr) final { @@ -219,8 +246,12 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in(); - expr_ = new GatherOp( - out, in, gather_expr->windowShape(), gather_expr->padWidth()); + expr_ = IrBuilder::create( + gather_expr->container(), + out, + in, + gather_expr->windowShape(), + gather_expr->padWidth()); } void handle(ViewOp* view_expr) final { @@ -234,7 +265,7 @@ struct SubstituteInExpr : public OptInDispatch { auto out = reference_->sameAs(view_expr->out()) ? 
substitute_->as() : view_expr->out(); - expr_ = new ViewOp(out, in); + expr_ = IrBuilder::create(view_expr->container(), out, in); } void handle(WelfordOp* welford_expr) final { @@ -268,7 +299,8 @@ struct SubstituteInExpr : public OptInDispatch { welford_expr->initN() && reference_->sameAs(welford_expr->initN()) ? substitute_ : welford_expr->initN(); - expr_ = new WelfordOp( + expr_ = IrBuilder::create( + welford_expr->container(), out_avg, out_var, out_N, @@ -402,13 +434,31 @@ std::vector allTvs(Fusion* fusion) { return uniqueEntries({used_tvs.begin(), used_tvs.end()}); } -std::vector historyOf(TensorDomain* td) { - return ExprSort::getExprs( - td->fusion(), {td->domain().begin(), td->domain().end()}); -} - -std::vector historyOf(TensorView* tv) { - return historyOf(tv->domain()); +std::vector getReductionOps(Fusion* fusion) { + std::vector red_ops; + for (auto expr : fusion->exprs()) { + const Val* out_val = nullptr; + if (expr->isA()) { + out_val = expr->as()->out(); + } else if (expr->isA()) { + out_val = expr->as()->outAvg(); + } else { + continue; + } + if (out_val == nullptr || !out_val->isA()) { + continue; + } + auto out_tv = out_val->as(); + if (std::any_of( + out_tv->getRootDomain().begin(), + out_tv->getRootDomain().end(), + [](IterDomain* id) { + return id->isReduction() && !id->isTrivialReduction(); + })) { + red_ops.push_back(expr); + } + } + return red_ops; } } // namespace ir_utils diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index c8dc2e6f679..1bf3f27ec0b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -110,6 +110,9 @@ auto filterByType(InputIt first, InputIt last) { return FilteredView(first, last); } +template +auto filterByType(const ContainerType&& inputs) = delete; + template auto filterByType(const ContainerType& inputs) { return filterByType(inputs.cbegin(), inputs.cend()); @@ -175,11 +178,7 @@ TORCH_CUDA_CU_API std::vector outputTvsOf( // returns all tensor views in fusion that are used between outputs and inputs. TORCH_CUDA_CU_API std::vector allTvs(Fusion* fusion); -// Returns the history of expressions applied to the domains of td -TORCH_CUDA_CU_API std::vector historyOf(TensorDomain* td); - -// Returns the history of expressions applied to the domains of tv -TORCH_CUDA_CU_API std::vector historyOf(TensorView* tv); +TORCH_CUDA_CU_API std::vector getReductionOps(Fusion* fusion); } // namespace ir_utils } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 344df98f5a7..894b40f79e3 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace torch { @@ -31,21 +32,94 @@ void remove_visited( } } +// Return all dependencies of a node including members of the node. 
+class RecursiveDependencies : public OptInDispatch { + public: + static std::vector next(Statement* stmt) { + RecursiveDependencies find_next(stmt); + return find_next.next_stmts_; + } + + private: + RecursiveDependencies() = default; + + RecursiveDependencies(Statement* stmt) { + handle(stmt); + } + + using OptInDispatch::handle; + + void handle(Expr* expr) final { + FusionGuard::getCurFusion()->assertInContainer( + expr, + "IterVisitor.cpp::RecursiveDependencies::handle(Expr*) Cannot traverse expr, "); + next_stmts_.insert( + next_stmts_.end(), expr->inputs().begin(), expr->inputs().end()); + } + + void handle(Val* val) final { + FusionGuard::getCurFusion()->assertInContainer( + val, + "IterVisitor.cpp::RecursiveDependencies::handle(Val*) Cannot traverse val, "); + OptInDispatch::handle(val); + } + + void simpleVal(Val* val) { + if (val->definition() == nullptr) { + return; + } + next_stmts_.push_back(val->definition()); + } + + void handle(Bool* stmt) final { + simpleVal(stmt); + } + + void handle(Double* stmt) final { + simpleVal(stmt); + } + + void handle(Int* stmt) final { + simpleVal(stmt); + } + + void handle(NamedScalar* stmt) final { + simpleVal(stmt); + } + + void handle(IterDomain* stmt) final { + next_stmts_.push_back(stmt->start()); + next_stmts_.push_back(stmt->extent()); + next_stmts_.push_back(stmt->stopOffset()); + simpleVal(stmt); + } + + void handle(TensorDomain* stmt) final { + next_stmts_.insert( + next_stmts_.end(), stmt->domain().begin(), stmt->domain().end()); + simpleVal(stmt); + } + + void handle(TensorView* tv) final { + next_stmts_.push_back(tv->domain()); + simpleVal(tv); + } + + std::vector next_stmts_; +}; + } // namespace std::vector IterVisitor::next(Statement* stmt) { if (stmt->isVal()) { return next(stmt->as()); - } else if (stmt->isExpr()) { - return next(stmt->as()); } else { - TORCH_INTERNAL_ASSERT( - false, "IterVisitor could not detect type in next_dispatch."); + return next(stmt->as()); } } std::vector IterVisitor::next(Val* v) { - FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, "); + FusionGuard::getCurFusion()->assertInContainer(v, "Cannot traverse val, "); if (v->definition() != nullptr) { return {v->definition()}; } @@ -53,7 +127,8 @@ std::vector IterVisitor::next(Val* v) { } std::vector IterVisitor::next(Expr* expr) { - FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, "); + FusionGuard::getCurFusion()->assertInContainer( + expr, "Cannot traverse expr, "); std::vector next_stmts{ expr->inputs().begin(), expr->inputs().end()}; return next_stmts; @@ -93,7 +168,8 @@ void IterVisitor::handle(Val* v) { void IterVisitor::traverseFrom( Fusion* fusion, const std::vector& from, - bool traverseAllPaths) { + bool traverseAllPaths, + bool traverseIntoMembers) { FusionGuard fg(fusion); std::unordered_set visited; @@ -137,7 +213,8 @@ void IterVisitor::traverseFrom( } else { // We're not ready to process this node, so add all its inputs to be // checked Visit input nodes. - auto next_stmts = next(stmt); + auto next_stmts = + traverseIntoMembers ? RecursiveDependencies::next(stmt) : next(stmt); // We may want to retraverse nodes, in that case revisit everything! if (!traverseAllPaths) { // If we don't want to retraverse, remove nodes we already visisted. 
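// Illustrative sketch of the new member-traversal option (the fusion pointer
// and helper name are assumptions; assumes the nvfuser headers and namespaces
// from the files above): with traverse_members=true the traversal also steps
// into TensorDomains, IterDomains, and their start/extent/stopOffset values,
// so StmtSort can return those statements in addition to tensor expressions.
#include <iostream>

void dumpStatementCounts(Fusion* fusion) {
  // Tensor-level expressions only, matching the previous ExprSort behavior.
  auto exprs = StmtSort::getExprs(fusion);
  // Every reachable Statement, including IterDomain members and their extents.
  auto stmts = StmtSort::getStmts(fusion, /*traverse_members=*/true);
  std::cout << exprs.size() << " exprs, " << stmts.size() << " statements\n";
}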
@@ -308,7 +385,7 @@ void BackwardVisitor::traverseFrom( auto vals = AllVals::get(fusion, from); - auto exprs = ExprSort::getExprs(fusion, from); + auto exprs = StmtSort::getExprs(fusion, from); { size_t pos = 0; @@ -704,22 +781,41 @@ std::unordered_set DependencyCheck::getAllDependentVals( return DependentVals::getAllDependentVals(of); } -void ExprSort::handle(Expr* expr) { - exprs.push_back(expr); +void StmtSort::handle(Statement* stmt) { + stmts.push_back(stmt); } -std::vector ExprSort::getExprs(Fusion* fusion) { - ExprSort es; - es.traverse(fusion); - return es.exprs; +std::vector StmtSort::getExprs(Fusion* fusion, bool traverse_members) { + auto terminating_outputs = fusion->getTerminatingOutputs(); + return StmtSort::getExprs(fusion, terminating_outputs, traverse_members); } -std::vector ExprSort::getExprs( +std::vector StmtSort::getExprs( Fusion* fusion, - const std::vector& from) { - ExprSort es; - es.traverseFrom(fusion, from, false); - return es.exprs; + const std::vector& from, + bool traverse_members) { + StmtSort es; + es.traverseFrom(fusion, from, false, traverse_members); + auto stmts = StmtSort::getStmts(fusion, from, traverse_members); + auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); + std::vector exprs(filter.begin(), filter.end()); + return exprs; +} + +std::vector StmtSort::getStmts( + Fusion* fusion, + bool traverse_members) { + auto terminating_outputs = fusion->getTerminatingOutputs(); + return StmtSort::getStmts(fusion, terminating_outputs, traverse_members); +} + +std::vector StmtSort::getStmts( + Fusion* fusion, + const std::vector& from, + bool traverse_members) { + StmtSort es; + es.traverseFrom(fusion, from, false, traverse_members); + return es.stmts; } void InputsOf::handle(Val* v) { diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index d4aa56ea2fe..2447933d737 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -83,18 +83,21 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { void traverseHelper(Fusion* fusion, bool traverse_all_paths = false); public: - // Starts at nodes provided in from, traverses from these nodes to inputs. - // Calls handle on all Statement*s in topological sorted order. - // traverseAllPaths = false only call handle on each Statement* once - // traverseAllPaths = true traverses all paths from nodes in from to inputs. - // Handle on a Statement* for every path from "from" nodes, to inputs. - // to argument allows specification of nodes to stop at if we want to stop - // beffore we hit all leaf nodes. This can be helpful when we want to traverse - // from TensorView::domain(), to the rfactor domain, instead of root domain. + //! Starts at nodes provided in from, traverses from these nodes to inputs. + //! Calls handle on all Statement*s in topological sorted order. + //! \param traverseAllPaths = false only call handle on each Statement* once + //! traverseAllPaths = true traverses all paths from nodes in from to + //! inputs. Calls handle on a Statement* for every path from "from" nodes, + //! to inputs. + //! \param traverseIntoMembers = When hitting nodes like TensorView, + //! TensorDomain, or IterDomain where there are members of the nodes that are + //! Val's a value of "true" will also traverse into those member Val's, a + //! value of "false" will not traverse into the members. 
void traverseFrom( Fusion* fusion, const std::vector& from, - bool traverseAllPaths = false); + bool traverseAllPaths = false, + bool traverseIntoMembers = false); // Iterates from terminating outputs registered with the fusion. Terminating // means value is not used to generate any other value used in producing @@ -246,18 +249,40 @@ class TORCH_CUDA_CU_API DependencyCheck { // Expr sort will take a fusion and return a topologically sorted list of // expressions. -class ExprSort : public IterVisitor { +class StmtSort : public IterVisitor { protected: - std::vector exprs; + std::vector stmts; - void handle(Expr* expr) override; + void handle(Statement* stmt) override; public: - static std::vector getExprs(Fusion* fusion); + // If traverse_members it will also extract all member nodes in the sorted + // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc + static std::vector getExprs( + Fusion* fusion, + bool traverse_members = false); + // If traverse_members it will also extract all member nodes in the sorted + // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc static std::vector getExprs( Fusion* fusion, - const std::vector& from); + const std::vector& from, + bool traverse_members = false); + + // If traverse_members it will also extract all member nodes in the sorted + // statement list in the fusion. i.e. all IterDomains, extents, and associated + // expressions of them + static std::vector getStmts( + Fusion* fusion, + bool traverse_members = false); + + // If traverse_members it will also extract all member nodes in the sorted + // expr list in the fusion. i.e. all IterDomains, extents, and associated + // expressions of them + static std::vector getStmts( + Fusion* fusion, + const std::vector& from, + bool traverse_members = false); }; class InputsOf : public IterVisitor { diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index d3ef9eeb95d..b9062f5bc45 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -1,7 +1,8 @@ #include +#include #include #include -#include +#include #include #include @@ -11,22 +12,24 @@ namespace torch { namespace jit { namespace fuser { namespace cuda { + +IrBuilderPasskey::IrBuilderPasskey(IrContainer* ir_container) + : ir_container_(ir_container) {} + namespace kir { namespace { //! Scan all primary expressions in the Kernel IR and build //! 
lists of specialized nodes and other interesting information -class KernelIrScanner : private kir::IrVisitor { +class KernelIrScanner : private IrVisitor { public: explicit KernelIrScanner(const Kernel* kernel) { - for (const auto& ir_node : kernel->irNodes()) { - ir_node->accept(this); - } + IrVisitor::handle(kernel->topLevelExprs()); const auto gpu_lower = GpuLower::current(); for (auto split : gpu_lower->nonDivisibleSplitInfo().splitsToValidate()) { - auto extent = gpu_lower->lowerValue(split->in()->extent()); - auto factor = gpu_lower->lowerValue(split->factor()); + auto extent = split->in()->extent(); + auto factor = split->factor(); summary_.splits_to_validate.emplace_back(extent, factor); } } @@ -36,7 +39,17 @@ class KernelIrScanner : private kir::IrVisitor { } private: - void visit(const kir::Sync* sync) final { + using IrVisitor::handle; + void handle(Expr* expr) final { + IrVisitor::handle(expr); + for (auto inp : expr->inputs()) { + handle(inp); + } + for (auto out : expr->outputs()) { + handle(out); + } + } + void handle(Sync* sync) final { // TODO: Move to a dedicated validation pass // which is not on the common execution/compilation path if (sync->isWarHazardSync()) { @@ -44,7 +57,7 @@ class KernelIrScanner : private kir::IrVisitor { } } - void visit(const kir::Allocate* allocate) final { + void handle(Allocate* allocate) final { switch (allocate->memoryType()) { case MemoryType::Global: summary_.global_allocations.push_back(allocate); @@ -65,28 +78,23 @@ class KernelIrScanner : private kir::IrVisitor { } } - void visit(const kir::UnaryOp* unary_op) final { - if (unary_op->operation() == UnaryOpType::RandLike) { + void handle(UnaryOp* unary_op) final { + if (unary_op->getUnaryOpType() == UnaryOpType::RandLike) { // This kernel is using random numbers summary_.is_stochastic = true; } } - void visit(const kir::TensorIndex* tensor_index) final { + void handle(TensorIndex* tensor_index) final { const auto tv = tensor_index->view(); const auto domain = tv->domain(); - // Do we have any reductions? summary_.has_block_reductions = summary_.has_block_reductions || domain->hasBlockReduction(); - // Do we have block broadcasts? 
- summary_.has_block_broadcasts = - summary_.has_block_broadcasts || domain->hasBlockBroadcast(); - // Update the largest smem data type if (domain->hasBlockReduction() || domain->hasGridReduction() || - tv->memoryType() == MemoryType::Shared) { + tv->getMemoryType() == MemoryType::Shared) { const auto data_type = tv->dtype(); const size_t type_size = dataTypeSize(data_type); if (type_size > max_smem_type_size_) { @@ -94,38 +102,50 @@ class KernelIrScanner : private kir::IrVisitor { summary_.largest_smem_data_type = data_type; } } + } - // Update Welford - if (tensor_index->definition() != nullptr && - tensor_index->definition()->isA()) { - summary_.has_welford = true; - summary_.has_block_welford = - summary_.has_block_welford || domain->hasBlockReduction(); - summary_.has_grid_welford = - summary_.has_grid_welford || domain->hasGridReduction(); - } + void handle(WelfordOp* welford_op) final { + summary_.has_welford = true; + TORCH_INTERNAL_ASSERT(welford_op->outAvg()->isA()); + auto out_dom = welford_op->outAvg()->as()->view()->domain(); + summary_.has_block_welford = + summary_.has_block_welford || out_dom->hasBlockReduction(); } - void visit(const kir::GridWelford* grid_welford) final { - const auto dom = grid_welford->welford_op() - ->out() - ->as() - ->view() - ->domain(); + void handle(GridWelford* grid_welford) final { + summary_.has_welford = true; + summary_.has_grid_welford = true; + const auto dom = + grid_welford->welford_op()->out()->as()->view()->domain(); updateGridReductionInLoop(dom); } - void visit(const kir::GridReduction* grid_reduction) final { + void handle(GridReduction* grid_reduction) final { + summary_.has_grid_reductions = true; const auto dom = grid_reduction->reduction_op() ->out() - ->as() + ->as() ->view() ->domain(); updateGridReductionInLoop(dom); } - void visit(const kir::GridBroadcast*) final { + void handle(GridBroadcast* grid_broadcast) final { summary_.has_cooperative_grid_reduction = true; + handle(grid_broadcast->broadcast_op()); + } + + void handle(BroadcastOp* bop) final { + const ParallelTypeBitmap parallel_types = + GpuLower::current()->threadPredMap().getParallelBroadcastDomains( + bop->out()->as()->view()); + summary_.broadcast_parallel_types.emplace(bop, parallel_types); + // Do we have block broadcasts? + summary_.has_block_broadcasts = + summary_.has_block_broadcasts || parallel_types.hasTID(); + // Do we have grid broadcasts? + summary_.has_grid_broadcasts = + summary_.has_grid_broadcasts || parallel_types.hasBID(); } private: @@ -136,10 +156,9 @@ class KernelIrScanner : private kir::IrVisitor { void updateGridReductionInLoop(TensorDomain* dom) { summary_.has_grid_reductions = true; - const auto gpu_lower = GpuLower::current(); for (const auto i : c10::irange(dom->nDims())) { - const auto id = - gpu_lower->caParallelMap().getConcreteMappedID(dom->domain()[i]); + const auto id = GpuLower::current()->caParallelMap().getConcreteMappedID( + dom->domain()[i]); summary_.has_cooperative_grid_reduction = summary_.has_cooperative_grid_reduction || @@ -169,7 +188,7 @@ class KernelIrScanner : private kir::IrVisitor { //! MemoryType::Global for tensors parallelized with blockIdx), it is //! assumed that allocation is properly extended for the iteration //! count. 
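// Illustrative sketch of the placement rule the validation pass below checks
// (the free-function form and its name are assumptions; the real check runs on
// live Allocate nodes inside the lowered loop nest): an allocation mapped to a
// thread-parallel loop must live in shared or global memory, and one mapped to
// a block-parallel loop must live in global memory.
bool allocationPlacementIsValid(MemoryType mem_type, ParallelType loop_ptype) {
  if (isParallelTypeThreadDim(loop_ptype)) {
    return mem_type == MemoryType::Shared || mem_type == MemoryType::Global;
  }
  if (isParallelTypeBlockDim(loop_ptype)) {
    return mem_type == MemoryType::Global;
  }
  return true; // serial / non-parallel loops place no constraint
}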
-class ValidateAllocation : private kir::IrVisitor { +class ValidateAllocation : private OptOutConstDispatch { public: static void validate(const Kernel* kernel) { ValidateAllocation validate_allocation(kernel); @@ -178,14 +197,14 @@ class ValidateAllocation : private kir::IrVisitor { private: explicit ValidateAllocation(const Kernel* kernel) { live_allocations_.emplace_back(std::vector()); - for (const auto& ir_node : kernel->topLevelExprs()) { - ir_node->accept(this); + for (const auto& expr : kernel->topLevelExprs()) { + OptOutConstDispatch::handle(expr); } live_allocations_.pop_back(); TORCH_INTERNAL_ASSERT(live_allocations_.empty()); } - void visit(const kir::Allocate* allocate) final { + void handle(const Allocate* allocate) final { TORCH_INTERNAL_ASSERT(!live_allocations_.empty()); live_allocations_.back().push_back(allocate); } @@ -195,53 +214,52 @@ class ValidateAllocation : private kir::IrVisitor { // during in the allocation lowering if it's thread-parallel and not // allocated on shared or global memories, or if it's block-parallel // ando not allocated on global memory. - void validate(const kir::ForLoop* for_loop) { + void validate(const ForLoop* for_loop) { const auto loop_id = for_loop->iter_domain(); - const auto gpu_lower = GpuLower::current(); for (const auto& allocations : live_allocations_) { for (const auto& allocate : allocations) { - const auto tv = dynamic_cast(allocate->buffer()); + const auto tv = dynamic_cast(allocate->buffer()); if (tv == nullptr) { continue; } for (const auto& axis : tv->domain()->domain()) { - if (!gpu_lower->caParallelMap().areMapped(loop_id, axis)) { + if (!GpuLower::current()->caParallelMap().areMapped(loop_id, axis)) { continue; } - if (isParallelTypeThreadDim(loop_id->parallelType())) { + if (isParallelTypeThreadDim(loop_id->getParallelType())) { TORCH_INTERNAL_ASSERT( - tv->memoryType() == MemoryType::Shared || - tv->memoryType() == MemoryType::Global, + tv->getMemoryType() == MemoryType::Shared || + tv->getMemoryType() == MemoryType::Global, "Tensor t", tv->name(), " must be allocated on SMEM or GMEM."); - } else if (isParallelTypeBlockDim(loop_id->parallelType())) { - TORCH_INTERNAL_ASSERT(tv->memoryType() == MemoryType::Global); + } else if (isParallelTypeBlockDim(loop_id->getParallelType())) { + TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global); } } } } } - void visit(const kir::ForLoop* for_loop) final { + void handle(const ForLoop* for_loop) final { if (for_loop->stop() != for_loop->iter_domain()->extent() && - isParallelTypeThread(for_loop->iter_domain()->parallelType())) { + isParallelTypeThread(for_loop->iter_domain()->getParallelType())) { validate(for_loop); } live_allocations_.emplace_back(std::vector()); for (const auto& expr : for_loop->body().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } live_allocations_.pop_back(); } - void visit(const kir::IfThenElse* ite) final { + void handle(const IfThenElse* ite) final { for (const auto& expr : ite->thenBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } for (const auto& expr : ite->elseBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } @@ -252,11 +270,9 @@ class ValidateAllocation : private kir::IrVisitor { } // namespace // TODO(kir): Kernel IR validation -void Kernel::finalize(std::vector top_level_exprs) { - TORCH_CHECK(top_level_exprs_.empty()); +void Kernel::finalize(std::vector top_level_exprs) { + TORCH_INTERNAL_ASSERT(top_level_exprs_.empty()); top_level_exprs_ = 
std::move(top_level_exprs); - predicate_map_ = std::make_unique( - GpuLower::current()->threadPredMap()); warp_padded_parallel_info_ = GpuLower::current()->getWarpPaddedParallelInfo(); ValidateAllocation::validate(this); analyze(); @@ -270,8 +286,63 @@ void Kernel::analyze() { } void Kernel::print() const { - kir::IrPrinter ir_printer(std::cout); - ir_printer.printKernel(this); + IrPrinter ir_printer(std::cout); + ir_printer.handle(this); +} + +//! Register the Val with this fusion +void Kernel::registerVal(Val* val) { + if (inContainer(val)) { + return; + } + if (val->kernel()) { + TORCH_CHECK( + val->kernel() == this, + val->toString(), + " was not found in the active kernel."); + } + + Fusion::registerVal(val); +} + +//! Register expr with this fusion. +//! When we register an expression, we want to update the dependency tracking +//! of Vals. We add expr to our general expr_set_, +void Kernel::registerExpr(Expr* expr) { + if (inContainer(expr)) { + return; + } + + if (expr->kernel()) { + TORCH_CHECK( + expr->kernel() == this, + expr->toString(), + " was not found in the active kernel."); + } + + for (Val* input : expr->inputs()) { + TORCH_INTERNAL_ASSERT( + inContainer(input), + "Input\n", + input->toString(), + " to expr,\n", + expr->toString(), + ",\n is invalid because it is not in the same kernel."); + } + + for (Val* output : expr->outputs()) { + TORCH_INTERNAL_ASSERT( + inContainer(output), + "Output\n", + output->toString(), + " to expr,\n", + expr->toString(), + ",\n is invalid because it is not in the same kernel."); + } + + // Register expr is explicitly non-SSA when coming from a kernel. This is + // detected inside Fusion::registerExpr + Fusion::registerExpr(expr); } } // namespace kir diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h index b273324e1e2..0c8bbdef9df 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.h +++ b/torch/csrc/jit/codegen/cuda/kernel.h @@ -1,12 +1,15 @@ #pragma once -#include -#include -#include +#include + +#include +#include +#include #include #include #include +#include #include #include @@ -47,6 +50,9 @@ struct KernelSummary { //! Do we have any block broadcasts? bool has_block_broadcasts = false; + //! Do we have any grid broadcasts? + bool has_grid_broadcasts = false; + //! Do we have any welford op? bool has_welford = false; @@ -67,87 +73,47 @@ struct KernelSummary { std::vector dynamic_lmem_allocations; //! ceilDiv extents that must be divisible - std::vector> splits_to_validate; + std::vector> splits_to_validate; + + //! Effective ParallelTypes of broadcast ops + std::unordered_map + broadcast_parallel_types; }; //! Container for a lowered Kernel IR //! -//! TODO(kir): currently, it is just pointing to nodes owned -//! by a Fusion object. The goal is to have the Kernel object -//! own the Kernel IR nodes -//! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API Kernel final : public NonCopyable { +class TORCH_CUDA_CU_API Kernel final : public Fusion { public: - Kernel() = default; + // Kernel starts by grabbing all the nodes from the provided fusion. + // Kernel is not SSA, if a definition is not set, we should update it, but + // not remove previous definition if it is set. This is primarily because when + // we do something like generate an initialization statement for a reduction + // TV, we may want to continue to do fusion like analysis on the original + // expression. 
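// Kernel now derives from Fusion, copies the provided fusion's nodes (see the
// constructor below), and registerVal/registerExpr refuse IR that belongs to
// a different container. A rough, self-contained sketch of that ownership
// check; Val, Expr and Container are stand-in types rather than the real
// nvFuser classes, and plain assert stands in for TORCH_INTERNAL_ASSERT.

#include <cassert>
#include <unordered_set>
#include <vector>

struct Val {};
struct Expr {
  std::vector<Val*> inputs;
  std::vector<Val*> outputs;
};

class Container {
 public:
  void registerVal(Val* val) {
    vals_.insert(val);
  }

  void registerExpr(Expr* expr) {
    if (exprs_.count(expr)) {
      return; // already registered, nothing to do
    }
    // Every input and output must already live in this container, mirroring
    // the checks in Kernel::registerExpr above.
    for (Val* in : expr->inputs) {
      assert(vals_.count(in) && "input is not in the same container");
    }
    for (Val* out : expr->outputs) {
      assert(vals_.count(out) && "output is not in the same container");
    }
    exprs_.insert(expr);
  }

 private:
  std::unordered_set<Val*> vals_;
  std::unordered_set<Expr*> exprs_;
};

int main() {
  Container c;
  Val a;
  Val b;
  c.registerVal(&a);
  c.registerVal(&b);
  Expr e;
  e.inputs = {&a};
  e.outputs = {&b};
  c.registerExpr(&e);
  return 0;
}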
+ Kernel(Fusion* fusion) : Fusion(*fusion) {} + + Kernel() = delete; + + // No move or copy semantics + Kernel(const Kernel&) = delete; + Kernel& operator=(const Kernel&) = delete; //! Finalize a kernel definition //! //! At this point we have a complete kernel definition and we can //! run analysis passes to build a KernelSummary //! - void finalize(std::vector top_level_exprs); - - //! Register input as an input of the kernel - void addInput(Val* input) { - inputs_.push_back(input); - input_set_.insert(input); - } + void finalize(std::vector top_level_exprs); - //! Register output as an output of the kernel - void addOutput(Val* output) { - outputs_.push_back(output); - output_set_.insert(output); - } - - const auto& inputs() const { - return inputs_; - } - - const auto& outputs() const { - return outputs_; - } - - bool isInput(Val* val) const { - return input_set_.find(val) != input_set_.end(); - } - - bool isOutput(Val* val) const { - return output_set_.find(val) != output_set_.end(); - } - - const auto& topLevelExprs() const { + const std::vector& topLevelExprs() const { return top_level_exprs_; } - const auto& irNodes() const { - return ir_nodes_; - } - const KernelSummary& summary() const { return summary_; } - const ThreadPredicateMap& predicateMap() const { - return *predicate_map_; - } - - //! Register a new Kernel IR node - //! - //! \note This is a specialized helper for kir::IrBuilder, not - //! intendted for general use - //! - void registerIrNode(kir::Passkey passkey, std::unique_ptr node) { - TORCH_CHECK(passkey.kernel == this); - ir_nodes_.push_back(std::move(node)); - } - - //! Allocates a new value identifier - kir::ValueId newValueId(kir::Passkey passkey) { - TORCH_CHECK(passkey.kernel == this); - return next_value_id_++; - } - //! Checks if parallel type is padded bool isParallelTypePadded(ParallelType ptype) const { return ptype == ParallelType::TIDx && @@ -161,32 +127,26 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable { //! Debug dump of the Kernel IR void print() const; + protected: + //! Register the Val with this fusion + void registerVal(Val* val) override; + + //! Register expr with this fusion. + //! When we register an expression, we want to update the dependency tracking + //! of Vals. 
We add expr to our general expr_set_, + void registerExpr(Expr* expr) override; + private: // Analyze the kernel IR and caches the summary of interesting data void analyze(); private: - // Kernel IR nodes - std::vector> ir_nodes_; - // Top level statements - std::vector top_level_exprs_; - - // Kernel inputs and outputs - std::vector inputs_; - std::vector outputs_; - std::unordered_set input_set_; - std::unordered_set output_set_; - - // Used to allocate unique value IDs - kir::ValueId next_value_id_ = 1; + std::vector top_level_exprs_; // Summary of interesting kernel data KernelSummary summary_; - // Predicate map - // TODO(kir): consider a simpler, kernel IR based version - std::unique_ptr predicate_map_; WarpPaddedParallelInfo warp_padded_parallel_info_; }; diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index 39350876bd2..c1c113dbbc4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace torch { namespace jit { @@ -25,6 +26,10 @@ int getCommonDeviceCUDA(const at::ArrayRef& inputs) { continue; } const auto& device = input.toTensor().device(); + // skip cpu scalar tensor as they'll be promoted to scalar later + if (device.is_cpu() && is_cpu_scalar(input.toTensor())) { + continue; + } TORCH_CHECK(device.is_cuda(), "nvfuser only supports cuda device"); auto cur_index = device.index(); if (index != -1 && index != cur_index) { @@ -202,9 +207,9 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } // Access kernels associated with the common device id - auto dev_id = getCommonDeviceCUDA(inputs); - TORCH_INTERNAL_ASSERT(dev_id >= 0); - auto& kernel_runtimes = kernel_runtimes_[dev_id]; + auto device_index = getCommonDeviceCUDA(inputs); + TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); + auto& kernel_runtimes = kernel_runtimes_[device_index]; // Check for re-use hit case // a kernel runtime is re-usable if all the compiled @@ -277,14 +282,6 @@ FusionKernelRuntime::FusionKernelRuntime( } else { auto complete_fusion_heuristic = maybe_complete_fusion_heuristic.value(); - // Translate welfords if apply - if (fusion_copy->hasWelford()) { - bool translated = SegmentCandidateFinder::TranslateWelfordInFusion( - fusion_copy.get(), inputs); - if (translated) { - complete_fusion_heuristic = ScheduleHeuristic::Persistent; - } - } // Take ownership of the transformed fusion single_kernel_fusion_ = std::move(fusion_copy); @@ -358,7 +355,7 @@ std::vector FusionKernelRuntime::runKernelWithInput( launch_params = scheduler_entry->pointwiseParams().lparams; } executors_[group_id].compileFusion( - fusion_to_run.get(), options, inputs, launch_params); + fusion_to_run.get(), inputs, launch_params, options); } else { // Load launch params for reduction and normalization kernels if (scheduler_entry->hasReductionParam()) { @@ -453,6 +450,7 @@ std::vector FusionKernelRuntime::runWithInput( " inputs but expecting ", segmented_fusion_->inputs().size()); + c10::Device device(c10::DeviceType::CUDA, 0); int extent_index_ = 0; // Bind input in the tensor_map for (const auto i : c10::irange(inputs.size())) { @@ -466,6 +464,7 @@ std::vector FusionKernelRuntime::runWithInput( // more convenient and safer than replication if (inputs[i].isTensor()) { auto aten_tensor = inputs[i].toTensor(); + device = aten_tensor.device(); for (auto dim_size : aten_tensor.sizes()) { runtime_workspace_.tensor_map.emplace( 
runtime_workspace_.group_extent_binding_order[extent_index_++], @@ -504,14 +503,30 @@ std::vector FusionKernelRuntime::runWithInput( if (iter != runtime_workspace_.tensor_map.end()) { fusion_outputs.push_back(iter->second); } else { + bool empty_type_check = output->getDataType().has_value() && + output->getDataType().value() == DataType::Float; + + // Only support two cases of empty tensor here, since + // this is hot path. + auto out_tv = output->as(); + + // TODO: should be only one of the two once the "empty" + // definition has been unified throughout the ops. + bool empty_tensor_check = + out_tv->isZeroDim() || out_tv->isEmptyTensor(); + // This is the check for an empty tensor; TORCH_INTERNAL_ASSERT( - output->as()->nDims() == 0 && - output->getDataType().has_value() && - output->getDataType().value() == DataType::Float, + empty_tensor_check && empty_type_check, "Non empty tensor cannot be found at tensor_map in ", __FUNCTION__); - fusion_outputs.emplace_back(at::Tensor()); + + // TODO: would need to clean up this part when + // we have a unified and consistent way to generate + // size-0 tensors. + const auto tensor_options = + at::TensorOptions().dtype(at::kFloat).device(device); + fusion_outputs.emplace_back(at::empty({0}, tensor_options)); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index ae84c25e4f2..cba42f99dc4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -7,8 +7,8 @@ #include #include +#include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp index 7421d2e235a..3605f7a4155 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp @@ -1,7 +1,6 @@ #include #include -#include #include @@ -16,11 +15,11 @@ void ExpressionEvaluator::bind( Int::ScalarType concrete_value) { TORCH_CHECK(value->isScalar()); TORCH_CHECK(value->dtype() == DataType::Int); - TORCH_CHECK(!value->isConst(), "Tried to bind to a constant value"); + TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value"); TORCH_CHECK( value->definition() == nullptr, "Tried to bind to a value that is computed in the kernel IR: ", - toString(value), + value->toString(), " with ", concrete_value); known_values_[value] = concrete_value; @@ -41,14 +40,18 @@ void ExpressionEvaluator::bind( c10::optional ExpressionEvaluator::evaluate(const Val* value) { if (precomputed_integers_ && precomputed_integers_->ready()) { - return precomputed_integers_->getMaybeValueFor(value); - } else if (value->isScalar() && value->isConst()) { + if (precomputed_integers_->getMaybeValueFor(value).has_value()) { + return precomputed_integers_->getMaybeValueFor(value); + } + } + + if (value->isScalar() && value->isConst()) { return value->as()->value(); } else { FUSER_PERF_SCOPE("kir::ExpressionEvaluator::evaluate"); - TORCH_CHECK(value->isScalar()); - TORCH_CHECK(value->dtype() == DataType::Int); + TORCH_CHECK(value->isScalar(), value->toString()); + TORCH_CHECK(value->dtype() == DataType::Int, value->toString()); // Is the value known (either explicit binding or memoized)? 
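// Note on the lookup order in evaluate() at this point: when the precomputed
// integer table is ready and holds a value for this Val it is returned first;
// otherwise a constant scalar folds directly to its stored value; only then
// does evaluation fall back to the memoized known_values_ map below and, on a
// miss, to OptOutConstDispatch::handle(), which walks the value's definition
// so the result lands in known_values_ for later lookups.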
const auto pre_eval_it = known_values_.find(value); @@ -56,7 +59,7 @@ c10::optional ExpressionEvaluator::evaluate(const Val* value) { return pre_eval_it->second; } - value->accept(this); + OptOutConstDispatch::handle(value); const auto post_eval_it = known_values_.find(value); return post_eval_it != known_values_.end() @@ -74,24 +77,23 @@ void ExpressionEvaluator::print() const { std::cout << "\nEvaluation context\n"; std::cout << "--------------------\n"; for (const auto& kv : known_values_) { - std::cout << toString(kv.first) << " = " << kv.second << "\n"; + std::cout << kv.first->toString() << " = " << kv.second << "\n"; + } + std::cout << "\nPre-computed Values\n"; + if (precomputed_integers_ != nullptr) { + precomputed_integers_->print(); } std::cout << "--------------------\n\n"; } -void ExpressionEvaluator::unhandled(const void*) { - TORCH_INTERNAL_ASSERT( - false, "Kernel IR expression evaluation reached an unsupported node"); -} - -void ExpressionEvaluator::visit(const Int* value) { +void ExpressionEvaluator::handle(const Int* value) { TORCH_INTERNAL_ASSERT(!value->isConst()); if (auto def = value->definition()) { - def->accept(this); + OptOutConstDispatch::handle(def); } } -void ExpressionEvaluator::visit(const NamedScalar* named_scalar) { +void ExpressionEvaluator::handle(const NamedScalar* named_scalar) { const auto& name = named_scalar->name(); for (auto pt : kParallelTypeThreads) { auto pt_val_it = known_parallel_dimensions_.find(pt); @@ -105,10 +107,10 @@ void ExpressionEvaluator::visit(const NamedScalar* named_scalar) { } } -void ExpressionEvaluator::visit(const UnaryOp* unary_op) { +void ExpressionEvaluator::handle(const UnaryOp* unary_op) { const auto in = evaluate(unary_op->in()); if (in.has_value()) { - switch (unary_op->operation()) { + switch (unary_op->getUnaryOpType()) { case UnaryOpType::Neg: known_values_[unary_op->out()] = -*in; break; @@ -121,11 +123,11 @@ void ExpressionEvaluator::visit(const UnaryOp* unary_op) { } } -void ExpressionEvaluator::visit(const BinaryOp* binary_op) { +void ExpressionEvaluator::handle(const BinaryOp* binary_op) { const auto lhs = evaluate(binary_op->lhs()); const auto rhs = evaluate(binary_op->rhs()); if (lhs.has_value() && rhs.has_value()) { - switch (binary_op->operation()) { + switch (binary_op->getBinaryOpType()) { case BinaryOpType::Add: known_values_[binary_op->out()] = *lhs + *rhs; break; diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h index 64791387543..63586857ad8 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h @@ -1,7 +1,9 @@ #pragma once -#include +#include + +#include #include #include @@ -34,7 +36,7 @@ namespace kir { //! } //! ``` //! -class TORCH_CUDA_CU_API ExpressionEvaluator : private IrVisitor { +class TORCH_CUDA_CU_API ExpressionEvaluator : private OptInConstDispatch { public: //! 
Set a concrete value for a symbolic value void bind(const Val* value, Int::ScalarType concrete_value); @@ -56,11 +58,10 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private IrVisitor { } private: - void unhandled(const void*) final; - void visit(const Int* value) final; - void visit(const NamedScalar* named_scalar) final; - void visit(const UnaryOp* unary_op) final; - void visit(const BinaryOp* binary_op) final; + void handle(const Int* value) final; + void handle(const NamedScalar* named_scalar) final; + void handle(const UnaryOp* unary_op) final; + void handle(const BinaryOp* binary_op) final; private: std::unordered_map known_values_; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index eebfd41729c..5d2eb44f8a8 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -1,8 +1,7 @@ +#include #include #include #include -#include -#include #include #include #include @@ -15,369 +14,52 @@ namespace fuser { namespace cuda { namespace kir { -void Node::print() const { - std::cout << "\n"; - IrPrinter(std::cout).printNode(this); - std::cout << "\n"; -} - -Val::Val(Passkey passkey, DataType dtype) : Node(passkey), dtype_(dtype) { - // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48534 - id_ = passkey.kernel->newValueId(passkey); -} - -namespace { - -// Traverse definition of all values involved in constructing the provided val. -// Check if all values involved are constant values, meaning the provided -// val is also a constant value. -class ConstCheck : IrVisitor { - private: - bool is_const_ = true; - - using IrVisitor::visit; - - void visit(const Bool* b) override { - is_const_ = is_const_ && b->isConst(); - } - - void visit(const Double* d) override { - is_const_ = is_const_ && d->isConst(); - } - - void visit(const Int* i) override { - is_const_ = is_const_ && i->isConst(); - } - - void visit(const NamedScalar* ns) override { - is_const_ = is_const_ && false; - } - - void visit(const Expr* expr) { - for (auto inp : expr->inputs()) { - visit(inp); - } - } - - void visit(const Val* val) { - if (val->definition() != nullptr) { - visit(val->definition()); - } else { - val->accept(this); - } - } - - public: - static bool isConst(const Val* val) { - ConstCheck cc; - cc.visit(val); - return cc.is_const_; - } -}; - -} // namespace - -bool Val::isConstScalar() const { - if (!isScalar()) - return false; - return ConstCheck::isConst(this); -} - -Expr* Expr::parentScope() const { - if (scope()) { - return scope()->owner(); - } else { - return nullptr; - } -} - -NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { - std::string parallel_dim = stringifyThreadSize(p_type); - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - return ir_builder.create(parallel_dim, DataType::Int); -} - -NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) { - std::string parallel_ind = stringifyThread(p_type); - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - return ir_builder.create(parallel_ind, DataType::Int); -} - -c10::optional NamedScalar::getParallelDim() const { - if (stringifyThreadSize(ParallelType::TIDx).compare(name()) == 0) { - return c10::optional(ParallelType::TIDx); - } else if (stringifyThreadSize(ParallelType::TIDy).compare(name()) == 0) { - return c10::optional(ParallelType::TIDy); - } else if (stringifyThreadSize(ParallelType::TIDz).compare(name()) == 0) { - return c10::optional(ParallelType::TIDz); - } else if 
(stringifyThreadSize(ParallelType::BIDx).compare(name()) == 0) { - return c10::optional(ParallelType::BIDx); - } else if (stringifyThreadSize(ParallelType::BIDy).compare(name()) == 0) { - return c10::optional(ParallelType::BIDy); - } else if (stringifyThreadSize(ParallelType::BIDz).compare(name()) == 0) { - return c10::optional(ParallelType::BIDz); - } - return c10::nullopt; -} - -c10::optional NamedScalar::getParallelIndex() const { - if (stringifyThread(ParallelType::TIDx).compare(name()) == 0) { - return c10::optional(ParallelType::TIDx); - } else if (stringifyThread(ParallelType::TIDy).compare(name()) == 0) { - return c10::optional(ParallelType::TIDy); - } else if (stringifyThread(ParallelType::TIDz).compare(name()) == 0) { - return c10::optional(ParallelType::TIDz); - } else if (stringifyThread(ParallelType::BIDx).compare(name()) == 0) { - return c10::optional(ParallelType::BIDx); - } else if (stringifyThread(ParallelType::BIDy).compare(name()) == 0) { - return c10::optional(ParallelType::BIDy); - } else if (stringifyThread(ParallelType::BIDz).compare(name()) == 0) { - return c10::optional(ParallelType::BIDz); - } - return c10::nullopt; -} - -IterDomain::IterDomain(Passkey passkey, Val* start, Val* extent) - : Val(passkey, DataType::Int), - start_(start), - stop_(extent), - extent_(extent) {} - -IterDomain::IterDomain( - Passkey passkey, - const fuser::cuda::IterDomain* iter_domain) - : Val(passkey, iter_domain->getDataType().value()), - start_(GpuLower::current()->lowerValue(iter_domain->start())), - stop_(GpuLower::current()->lowerValue(iter_domain->stop())), - extent_(GpuLower::current()->lowerValue(iter_domain->extent())), - parallel_type_(iter_domain->getParallelType()), - iter_type_(iter_domain->getIterType()), - is_rfactor_domain_(iter_domain->isRFactorProduct()), - is_simple_(iter_domain->definition() == nullptr), - is_padded_dimension_(iter_domain->hasPaddingToMultipleOfWarp()) { - // preserve the fusion node's name - setName(iter_domain->name()); -} - -//! Note that the parallel dimension, if available, may be different -//! from the actual extent of this IterDomain as the parallel -//! dimension is determined by the largest extent of IterDomains -//! sharing the same loop. 
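// The removed note above says the parallel dimension may exceed an individual
// IterDomain's extent, because the launch size is taken as the largest extent
// among the IterDomains sharing that parallel loop (smaller extents are then
// typically guarded by predicates). A conceptual, standalone illustration of
// that rule, not nvFuser code:

#include <algorithm>
#include <cstdint>
#include <vector>

// Extents of every IterDomain mapped to one parallel type, e.g. TIDx.
int64_t launchExtent(const std::vector<int64_t>& extents_sharing_dim) {
  int64_t result = 1; // an unused dimension still launches with size 1
  for (int64_t extent : extents_sharing_dim) {
    result = std::max(result, extent);
  }
  return result;
}

int main() {
  // Two tensors parallelized on TIDx with extents 32 and 128: the kernel
  // launches blockDim.x = 128, and the extent-32 loop does less work per
  // thread block, with the surplus threads masked off.
  return launchExtent({32, 128}) == 128 ? 0 : 1;
}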
-Val* IterDomain::extent() const { - TORCH_INTERNAL_ASSERT(extent_ != nullptr); - return extent_; -} - -TensorDomain::TensorDomain(Passkey passkey, std::vector domain) - : Val(passkey, DataType::Null), root_domain_(std::move(domain)) { - domain_ = root_domain_; - resetDomains(); -} - -TensorDomain::TensorDomain( - Passkey passkey, - const fuser::cuda::TensorDomain* tensor_domain) - : Val(passkey, DataType::Null), contiguity_(tensor_domain->contiguity()) { - // preserve the fusion node's name - setName(tensor_domain->name()); - - const auto lowerIterDomains = - [](const std::vector& domains) { - std::vector lowered_domains; - lowered_domains.reserve(domains.size()); - for (const auto iter_domain : domains) { - lowered_domains.push_back( - GpuLower::current()->lowerValue(iter_domain)->as()); - } - return lowered_domains; - }; - - root_domain_ = lowerIterDomains(tensor_domain->getRootDomain()); - domain_ = lowerIterDomains(tensor_domain->domain()); - no_bcast_domain_ = lowerIterDomains(tensor_domain->noBroadcasts()); - no_reduction_domain_ = lowerIterDomains(tensor_domain->noReductions()); - rfactor_domain_ = lowerIterDomains(tensor_domain->getRFactorDomain()); -} - -bool TensorDomain::hasReduction() const { - return no_reduction_domain_.size() != domain_.size(); -} - -bool TensorDomain::hasBlockReduction() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isReduction() && id->isThreadDim(); - }); -} - -bool TensorDomain::hasGridReduction() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isReduction() && id->isBlockDim(); - }); -} - -bool TensorDomain::hasBlockBroadcast() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isBroadcast() && id->isThreadDim(); - }); -} - -bool TensorDomain::hasGridBroadcast() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isBroadcast() && id->isBlockDim(); - }); -} - -bool TensorDomain::hasBroadcast() const { - return no_bcast_domain_.size() != domain_.size(); -} - -bool TensorDomain::hasRFactor() const { - return !rfactor_domain_.empty(); -} - -bool TensorDomain::hasVectorize() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->parallelType() == ParallelType::Vectorize || - id->parallelType() == ParallelType::MisalignedVectorize; - }); -} - -IterDomain* TensorDomain::axis(int i) const { - TORCH_INTERNAL_ASSERT(i >= 0 && i < int(domain_.size())); - return domain_[i]; -} - -std::vector TensorDomain::noReductions( - const std::vector& td) { - std::vector no_reduction_domains; - for (auto id : td) { - if (!id->isReduction()) { - no_reduction_domains.push_back(id); - } - } - return no_reduction_domains; -} - -std::vector TensorDomain::noBroadcasts( - const std::vector& td) { - std::vector no_broadcast_domains; - for (auto id : td) { - if (!id->isBroadcast()) { - no_broadcast_domains.push_back(id); - } - } - return no_broadcast_domains; -} - -TensorView::TensorView(Passkey passkey, const fuser::cuda::TensorView* tv) - : Val(passkey, tv->getDataType().value()), fuser_tv_(tv) { - setName(tv->name()); - domain_ = GpuLower::current()->lowerValue(tv->domain())->as(); - memory_type_ = tv->getMemoryType(); -} - -TensorView::TensorView( - Passkey passkey, - DataType dtype, - TensorDomain* domain, - MemoryType memory_type) - : Val(passkey, dtype), domain_(domain), memory_type_(memory_type) {} - -UnaryOp::UnaryOp(Passkey passkey, UnaryOpType operation, 
Val* out, Val* in) - : Expr(passkey), operation_(operation), out_(out), in_(in) { - addOutput(out); - addInput(in); -} - -BinaryOp::BinaryOp( - Passkey passkey, - BinaryOpType operation, - Val* out, - Val* lhs, - Val* rhs) - : Expr(passkey), operation_(operation), out_(out), lhs_(lhs), rhs_(rhs) { - addOutput(out); - addInput(lhs); - addInput(rhs); -} - -TernaryOp::TernaryOp( - Passkey passkey, - TernaryOpType operation, - Val* out, - Val* in1, - Val* in2, - Val* in3) - : Expr(passkey), - operation_(operation), - out_(out), - in1_(in1), - in2_(in2), - in3_(in3) { - addOutput(out); - addInput(in1); - addInput(in2); - addInput(in3); -} - -ReductionOp::ReductionOp( - Passkey passkey, - BinaryOpType operation, - Val* init, - Val* out, - Val* in) - : Expr(passkey), operation_(operation), init_(init), out_(out), in_(in) { - addOutput(out); - addInput(in); +Predicate::Predicate( + IrBuilderPasskey passkey, + PredicateType ptype, + const Expr* expr, + Bool* thread_pred) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(ptype), + expr_(expr), + thread_pred_(thread_pred) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT( + ptype != PredicateType::Unswitch && ptype != PredicateType::Manual); } -WelfordOp::WelfordOp( - Passkey passkey, - Val* out_var, - Val* out_avg, - Val* out_N, - Val* init_var, - Val* init_avg, - Val* init_N, - Val* in_var, - Val* in_avg, - Val* in_N) - : Expr(passkey), - out_var_(out_var), - out_avg_(out_avg), - out_N_(out_N), - init_var_(init_var), - init_avg_(init_avg), - init_N_(init_N), - in_var_(in_var), - in_avg_(in_avg), - in_N_(in_N) { - addOutput(out_avg); - addOutput(out_var); - addOutput(out_N); - - if (!in_N->isOneInt()) { - addInput(in_var); - } - addInput(in_avg); - addInput(in_N); +Predicate::Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(PredicateType::Unswitch), + unrolled_loop_(unrolled_loop) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr); } -BroadcastOp::BroadcastOp(Passkey passkey, Val* out, Val* in) - : Expr(passkey), out_(out), in_(in) { - TORCH_CHECK(in->isA() || in->isA()); - TORCH_CHECK(out->isA() || out->isA()); - addOutput(out); - addInput(in); +Predicate::Predicate(IrBuilderPasskey passkey, Bool* value) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(PredicateType::Manual), + value_(value) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT(value != nullptr); } TensorIndex::TensorIndex( - Passkey passkey, - const fuser::cuda::TensorView* view, + IrBuilderPasskey passkey, + const TensorView* view, std::vector indices) - : Val(passkey, view->getDataType().value()), - view_(GpuLower::current()->lowerValue(view)->as()), + : Val(passkey, ValType::TensorIndex, view->getDataType().value()), + view_(view), indices_(indices) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); TORCH_INTERNAL_ASSERT( std::all_of( indices.begin(), @@ -392,20 +74,33 @@ TensorIndex::TensorIndex( indices_.end()); // If indices becomes empty, just put one ZeroInt if (indices_.empty()) { - indices_.push_back(kir::IrBuilder(GpuLower::current()->kernel()).zeroVal()); + indices_.push_back(FusionGuard::getCurFusion()->zeroVal()); } } -Sync::Sync(Passkey passkey, bool 
war_sync) - : Expr(passkey), war_sync_(war_sync) {} +Sync::Sync(IrBuilderPasskey passkey, bool war_sync) + : Expr(passkey, ExprType::Sync), war_sync_(war_sync) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} -InitMagicZero::InitMagicZero(Passkey passkey) : Expr(passkey) {} +InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) + : Expr(passkey, ExprType::InitMagicZero) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} -UpdateMagicZero::UpdateMagicZero(Passkey passkey) : Expr(passkey) {} +UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) + : Expr(passkey, ExprType::UpdateMagicZero) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); - expr->setScope(this); } void Scope::insert_before(Expr* ref, Expr* expr) { @@ -440,11 +135,6 @@ void Scope::insert(size_t pos, Expr* expr) { void Scope::erase(std::vector::const_iterator pos) { // Remove the scope of the expr if this is the scope auto expr = *pos; - TORCH_INTERNAL_ASSERT( - expr->scope() == this, - "Inconsistent scoping of expression detected: ", - kir::toString(expr)); - expr->setScope(nullptr); exprs_.erase(pos); } @@ -470,7 +160,7 @@ void Scope::clear() { } ForLoop::ForLoop( - Passkey passkey, + IrBuilderPasskey passkey, IterDomain* iter_domain, Val* index, Val* start, @@ -479,7 +169,7 @@ ForLoop::ForLoop( bool vectorize, Val* vectorize_shift, bool unroll_required) - : Expr(passkey), + : Expr(passkey, ExprType::ForLoop), iter_domain_{iter_domain}, index_(index), start_(start), @@ -489,43 +179,42 @@ ForLoop::ForLoop( vectorize_shift_(vectorize_shift), unroll_required_(unroll_required), body_(this) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int); addInput(index); addInput(iter_domain); if (start_ == nullptr && iter_domain->isThread()) { - start_ = - IrBuilder(GpuLower::current()->kernel()) - .create( - stringifyThread(iter_domain->parallelType()), DataType::Int); + start_ = NamedScalar::getParallelIndex(iter_domain->getParallelType()); } if (step_ == nullptr) { if (iter_domain->isThread()) { - step_ = IrBuilder(GpuLower::current()->kernel()) - .create( - stringifyThreadSize(iter_domain->parallelType()), - DataType::Int); + step_ = NamedScalar::getParallelDim(iter_domain->getParallelType()); } else { - step_ = IrBuilder(GpuLower::current()->kernel()).oneVal(); + step_ = FusionGuard::getCurFusion()->oneVal(); } } } -ForLoop::ForLoop(Passkey passkey, IterDomain* iter_domain) +ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain) : ForLoop( passkey, iter_domain, - iter_domain->isBroadcast() - ? IrBuilder(GpuLower::current()->kernel()).zeroVal() - : IrBuilder(GpuLower::current()->kernel()) - .create(c10::nullopt), + iter_domain->isBroadcast() ? 
FusionGuard::getCurFusion()->zeroVal() + : IrBuilder::create(c10::nullopt), nullptr, nullptr, nullptr, - isParallelTypeVectorize(iter_domain->parallelType()), + isParallelTypeVectorize(iter_domain->getParallelType()), nullptr, - false) {} + false) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} -ForLoop::ForLoop(Passkey passkey, const ForLoop* other) +ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) : ForLoop( passkey, other->iter_domain(), @@ -535,7 +224,11 @@ ForLoop::ForLoop(Passkey passkey, const ForLoop* other) other->step(), other->vectorize(), other->vectorize_shift(), - other->isUnrollRequired()) {} + other->isUnrollRequired()) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} bool ForLoop::isUnrollable() const { // Start and stop must be constant, must not be a broadcast @@ -550,7 +243,7 @@ bool ForLoop::isUnrolled() const { if (isUnrollRequired() && !isUnrollable()) { TORCH_WARN( "Unroll required but not possible. Register allocation disabled. Loop index: ", - kir::toString(index_)); + index_->toString()); return false; } @@ -570,7 +263,7 @@ bool ForLoop::isUnrolled() const { } // Unrolling is technically possible but avoided - if (iter_domain()->parallelType() == ParallelType::Unswitch) { + if (iter_domain()->getParallelType() == ParallelType::Unswitch) { // Use ParallelType::Unroll if unrolling is desired. Note that // unswitched size-one loops are not unrolled as they are not // materialized as actual for-loops. @@ -605,8 +298,8 @@ Val* ForLoop::step() const { return step_; } -IfThenElse::IfThenElse(Passkey passkey, Predicate* cond) - : Expr(passkey), then_body_(this), else_body_(this) { +IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond) + : Expr(passkey, ExprType::IfThenElse), then_body_(this), else_body_(this) { setPredicate(cond); addInput(cond); } @@ -621,17 +314,19 @@ Val* TensorIndex::index(int i) const { } Allocate::Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, std::vector shape, bool zero_init) - : Expr(passkey), + : Expr(passkey, ExprType::Allocate), buffer_(buffer), memory_type_(memory_type), shape_(std::move(shape)), zero_init_(zero_init) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); if (!shape_.empty()) { TORCH_INTERNAL_ASSERT( (shape_.size() == 1 && shape_[0]->isOneInt()) || @@ -639,7 +334,7 @@ Allocate::Allocate( } else { TORCH_INTERNAL_ASSERT(buffer_->isA()); TORCH_INTERNAL_ASSERT( - buffer_->as()->memoryType() == memory_type_); + buffer_->as()->getMemoryType() == memory_type_); const auto domain = buffer_->as()->domain(); for (auto axis : domain->noReductions()) { shape_.push_back(axis->extent()); @@ -650,19 +345,19 @@ Allocate::Allocate( if (size_ == nullptr) { size_ = s; } else { - size_ = ir_builder.mulExpr(size_, s); + size_ = IrBuilder::mulExpr(size_, s); } } if (size_ == nullptr) { - size_ = ir_builder.oneVal(); + size_ = FusionGuard::getCurFusion()->oneVal(); } addInput(size_); } Allocate::Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, Val* size, @@ -672,31 +367,57 @@ Allocate::Allocate( buffer, memory_type, size == nullptr ? 
std::vector{} : std::vector{size}, - zero_init) {} + zero_init) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} GridReduction::GridReduction( - Passkey passkey, + IrBuilderPasskey passkey, ReductionOp* reduction_op, Allocate* reduction_buffer, Allocate* sync_buffer) - : Expr(passkey), + : Expr(passkey, ExprType::GridReduction), reduction_op_(reduction_op), reduction_buffer_(reduction_buffer), - sync_buffer_(sync_buffer) {} + sync_buffer_(sync_buffer) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +GridBroadcast::GridBroadcast( + IrBuilderPasskey passkey, + BroadcastOp* broadcast_op, + Allocate* broadcast_buffer, + Allocate* sync_buffer) + : Expr(passkey, ExprType::GridBroadcast), + broadcast_op_(broadcast_op), + broadcast_buffer_(broadcast_buffer), + sync_buffer_(sync_buffer) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} GridWelford::GridWelford( - Passkey passkey, + IrBuilderPasskey passkey, WelfordOp* welford_op, Allocate* var_buffer, Allocate* avg_buffer, Allocate* n_buffer, Allocate* sync_buffer) - : Expr(passkey), + : Expr(passkey, ExprType::GridWelford), welford_op_(welford_op), var_buffer_(var_buffer), avg_buffer_(avg_buffer), n_buffer_(n_buffer), - sync_buffer_(sync_buffer) {} + sync_buffer_(sync_buffer) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} } // namespace kir } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index c1ac6052783..ad6be90bf98 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -1,16 +1,13 @@ #pragma once -#include -#include - -// TODO(kir): remove these once the Kernel IR is separated from Fusion IR +#include #include -#include -#include #include +#include +#include +#include #include -#include #include #include @@ -21,26 +18,22 @@ namespace torch { namespace jit { namespace fuser { namespace cuda { -namespace kir { -class IrBuilder; -class Kernel; +class IrBuilderPasskey; // Abstract nodes -class Node; class Val; class Expr; // Values -class NamedScalar; -class Predicate; class Bool; class Double; class Int; +class NamedScalar; + class IterDomain; class TensorDomain; class TensorView; -class TensorIndex; // Expressions class UnaryOp; @@ -50,7 +43,14 @@ class ReductionOp; class WelfordOp; class BroadcastOp; -// Statements +namespace kir { +class Kernel; + +// Values +class Predicate; +class TensorIndex; + +// Expressions class Allocate; class Sync; class InitMagicZero; @@ -64,443 +64,17 @@ class GridWelford; // Expr container class Scope; -using ValueId = int32_t; - -//! Token used to restrict the access to Kernel IR creation -//! -//! A token is associated with a kernel, which is passed with the key -//! (Passkey::kernel) -//! -//! It is a "granular friendship" token, used to implement the "passkey" idiom: -//! https://www.spiria.com/en/blog/desktop-software/passkey-idiom-and-better-friendship-c -//! https://arne-mertz.de/2016/10/passkey-idiom -//! -class Passkey { - friend class IrBuilder; - - public: - Kernel* const kernel = nullptr; - - private: - explicit Passkey(Kernel* kernel) : kernel(kernel) {} -}; - -//! 
Kernel IR visitor interface -class TORCH_CUDA_CU_API IrVisitor : public PolymorphicBase { - public: - // TODO(kir): use Node* instead of void* - virtual void unhandled(const void* node) {} - - // Values - virtual void visit(const NamedScalar* named_scalar) { - unhandled(named_scalar); - } - virtual void visit(const Predicate* value) { - unhandled(value); - } - virtual void visit(const Bool* value) { - unhandled(value); - } - virtual void visit(const Double* value) { - unhandled(value); - } - virtual void visit(const Int* value) { - unhandled(value); - } - virtual void visit(const IterDomain* iter_domain) { - unhandled(iter_domain); - } - virtual void visit(const TensorDomain* tensor_domain) { - unhandled(tensor_domain); - } - virtual void visit(const TensorView* tensor_view) { - unhandled(tensor_view); - } - virtual void visit(const TensorIndex* tensor_index) { - unhandled(tensor_index); - } - - // Expressions - virtual void visit(const UnaryOp* node) { - unhandled(node); - } - virtual void visit(const BinaryOp* node) { - unhandled(node); - } - virtual void visit(const TernaryOp* node) { - unhandled(node); - } - virtual void visit(const ReductionOp* node) { - unhandled(node); - } - virtual void visit(const WelfordOp* node) { - unhandled(node); - } - virtual void visit(const BroadcastOp* node) { - unhandled(node); - } - - // Statements - virtual void visit(const Allocate* node) { - unhandled(node); - } - virtual void visit(const Sync* node) { - unhandled(node); - } - virtual void visit(const InitMagicZero* node) { - unhandled(node); - } - virtual void visit(const UpdateMagicZero* node) { - unhandled(node); - } - virtual void visit(const ForLoop* node) { - unhandled(node); - } - virtual void visit(const IfThenElse* node) { - unhandled(node); - } - virtual void visit(const GridReduction* node) { - unhandled(node); - } - virtual void visit(const GridBroadcast* node) { - unhandled(node); - } - virtual void visit(const GridWelford* node) { - unhandled(node); - } -}; - -//! 
Kernel IR visitor interface -class TORCH_CUDA_CU_API MutableIrVisitor : public PolymorphicBase { - public: - // TODO(kir): use Node* instead of void* - virtual void unhandled(const void*) {} - - // Values - virtual void visit(NamedScalar* named_scalar) { - unhandled(named_scalar); - } - virtual void visit(Predicate* value) { - unhandled(value); - } - virtual void visit(Bool* value) { - unhandled(value); - } - virtual void visit(Double* value) { - unhandled(value); - } - virtual void visit(Int* value) { - unhandled(value); - } - virtual void visit(IterDomain* iter_domain) { - unhandled(iter_domain); - } - virtual void visit(TensorDomain* tensor_domain) { - unhandled(tensor_domain); - } - virtual void visit(TensorView* tensor_view) { - unhandled(tensor_view); - } - virtual void visit(TensorIndex* tensor_index) { - unhandled(tensor_index); - } - - // Expressions - virtual void visit(UnaryOp* node) { - unhandled(node); - } - virtual void visit(BinaryOp* node) { - unhandled(node); - } - virtual void visit(TernaryOp* node) { - unhandled(node); - } - virtual void visit(ReductionOp* node) { - unhandled(node); - } - virtual void visit(BroadcastOp* node) { - unhandled(node); - } - - virtual void visit(WelfordOp* node) { - unhandled(node); - } - - // Statements - virtual void visit(Allocate* node) { - unhandled(node); - } - virtual void visit(Sync* node) { - unhandled(node); - } - virtual void visit(InitMagicZero* node) { - unhandled(node); - } - virtual void visit(UpdateMagicZero* node) { - unhandled(node); - } - virtual void visit(ForLoop* node) { - unhandled(node); - } - virtual void visit(IfThenElse* node) { - unhandled(node); - } - virtual void visit(GridReduction* node) { - unhandled(node); - } - virtual void visit(GridBroadcast* node) { - unhandled(node); - } - virtual void visit(GridWelford* node) { - unhandled(node); - } -}; - -//! Base class for Kernel IR nodes -class TORCH_CUDA_CU_API Node : public NonCopyable, public PolymorphicBase { - public: - explicit Node(Passkey) {} - - //! IR Visitor double-dispatch interface - //! (https://en.wikipedia.org/wiki/Visitor_pattern) - virtual void accept(IrVisitor* visitor) const = 0; - - //! Non constant IR Visitor - virtual void accept(MutableIrVisitor* visitor) = 0; - - //! Debug helper, prints the textual representation of an IR node - void print() const; -}; - -//! Generic value (scalar or tensor) -class TORCH_CUDA_CU_API Val : public Node { - public: - Val(Passkey passkey, DataType dtype); - - // TODO(kir): consider renaming - StmtNameType name() const { - return name_; - } - - void setName(StmtNameType name) { - name_ = name; - } - - ValueId id() const { - return id_; - } - - DataType dtype() const { - return dtype_; - } - - Expr* definition() const { - return definition_; - } - - void setDefinition(Expr* expr) { - // TODO(kir): extra checks on changing existing definitions? 
- definition_ = expr; - } - - virtual bool isScalar() const { - return false; - } - - bool isConstScalar() const; - - virtual bool isConst() const { - return false; - } - - // TODO(kir): revisit and find a better interface - virtual bool isZeroInt() const { - return false; - } - - virtual bool isOneInt() const { - return false; - } - - void setEvaluatorIndex(int to) { - TORCH_INTERNAL_ASSERT(evaluator_index_ == -1); - evaluator_index_ = to; - } - - int evaluatorIndex() const { - return evaluator_index_; - } - - private: - const DataType dtype_; - - // The expression which defines this value, or nullptr - Expr* definition_ = nullptr; - - // This is a value name preserved from the Fusion IR (optional) - StmtNameType name_ = kInvalidStmName; - - // All Kernel IR values have IDs (unique within the same Kernel) - ValueId id_ = -1; - - // Expr evaluator idx; - int evaluator_index_ = -1; -}; - -//! Base class for expressions and statements -//! -//! Expressions consume inputs and produce outputs (depending on the context -//! this may imply assignments). Currently some of the expressions -//! don't actually produce any outputs (ForLoop, IfThenElse) and they -//! model statements to be executed. -//! -//! TODO(kir): split the expressions, assignments and statements? -//! -class TORCH_CUDA_CU_API Expr : public Node { - public: - explicit Expr(Passkey passkey) : Node(passkey) {} - - const auto& inputs() const { - return inputs_; - } - - const auto& outputs() const { - return outputs_; - } - - Scope* scope() const { - return scope_; - } - - //! Set the current scope - void setScope(Scope* scope) { - scope_ = scope; - } - - Expr* parentScope() const; - - Predicate* predicate() const { - return predicate_; - } - - void setPredicate(Predicate* predicate) { - predicate_ = predicate; - } - - Predicate* writePredicate() const { - return write_predicate_; - } - - void setWritePredicate(Predicate* write_predicate) { - write_predicate_ = write_predicate; - } - - protected: - // TODO(kir): try to avoid this protected interface - void addInput(Val* input) { - inputs_.push_back(input); - } - - void addOutput(Val* output) { - output->setDefinition(this); - outputs_.push_back(output); - } - - private: - // TODO(kir): can we avoid this? - std::vector inputs_; - std::vector outputs_; - - // TODO(kir): revisit scope/nesting data structures - Scope* scope_ = nullptr; - - Predicate* predicate_ = nullptr; - // Only used for reduction-related expressions - Predicate* write_predicate_ = nullptr; -}; - -class TORCH_CUDA_CU_API NamedScalar final : public Val { - public: - // NOLINTNEXTLINE(modernize-pass-by-value) - NamedScalar(Passkey passkey, std::string name, DataType dtype) - : Val(passkey, dtype), name_(name) {} - - explicit NamedScalar(Passkey passkey, const fuser::cuda::NamedScalar* node) - : Val(passkey, node->getDataType().value()) { - name_ = node->name(); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - // TODO(kir): this is hiding and redefining Val::name() - const std::string& name() const { - return name_; - } - - // Return the named scalar extent of a parallel dimension (e.g. blockDim.x) - static NamedScalar* getParallelDim(ParallelType p_type); - - // Return the named scalar index of a parallel dimension (e.g. 
threadIdx.x) - static NamedScalar* getParallelIndex(ParallelType p_type); - - // Return the parallel type of this NamedScalar if it is an extent of a - // parallel dimension - c10::optional getParallelDim() const; - - // Return the parallel type of this NamedScalar if it is an index of a - // parallel dimension - c10::optional getParallelIndex() const; - - private: - std::string name_; -}; - class TORCH_CUDA_CU_API Predicate final : public Val { public: explicit Predicate( - Passkey passkey, + IrBuilderPasskey passkey, PredicateType ptype, const Expr* expr = nullptr, - Bool* thread_pred = nullptr) - : Val(passkey, DataType::Bool), - ptype_(ptype), - expr_(expr), - thread_pred_(thread_pred) { - TORCH_INTERNAL_ASSERT( - ptype != PredicateType::Unswitch && ptype != PredicateType::Manual); - } + Bool* thread_pred = nullptr); - explicit Predicate(Passkey passkey, ForLoop* unrolled_loop) - : Val(passkey, DataType::Bool), - ptype_(PredicateType::Unswitch), - unrolled_loop_(unrolled_loop) { - TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr); - } - - explicit Predicate(Passkey passkey, Bool* value) - : Val(passkey, DataType::Bool), - ptype_(PredicateType::Manual), - value_(value) { - TORCH_INTERNAL_ASSERT(value != nullptr); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } + explicit Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop); - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit Predicate(IrBuilderPasskey passkey, Bool* value); PredicateType predicate_type() const { return ptype_; @@ -543,6 +117,10 @@ class TORCH_CUDA_CU_API Predicate final : public Val { value_ = value; } + bool isConst() const final { + return hasValue() && value_->isConst(); + } + private: PredicateType ptype_ = PredicateType::Manual; @@ -561,603 +139,13 @@ class TORCH_CUDA_CU_API Predicate final : public Val { Bool* value_ = nullptr; }; -class TORCH_CUDA_CU_API Bool final : public Val { - public: - explicit Bool(Passkey passkey, const c10::optional& value) - : Val(passkey, DataType::Bool), maybe_value_(value) {} - - explicit Bool(Passkey passkey, const fuser::cuda::Bool* node) - : Val(passkey, DataType::Bool), maybe_value_(node->value()) { - setName(node->name()); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - bool isConst() const override { - return maybe_value_.has_value(); - } - - c10::optional value() const { - return maybe_value_; - } - - private: - const c10::optional maybe_value_; -}; - -class TORCH_CUDA_CU_API Double final : public Val { - public: - using ScalarType = double; - - explicit Double(Passkey passkey, const c10::optional& value) - : Val(passkey, DataType::Double), maybe_value_(value) {} - - explicit Double(Passkey passkey, const fuser::cuda::Double* node) - : Val(passkey, DataType::Double), maybe_value_(node->value()) { - setName(node->name()); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - bool isConst() const override { - return maybe_value_.has_value(); - } - - c10::optional value() const { - return maybe_value_; - } - - private: - const c10::optional maybe_value_; -}; - -class TORCH_CUDA_CU_API Int final : public Val { - public: - using ScalarType = int64_t; - - 
explicit Int(Passkey passkey, const c10::optional& value) - : Val(passkey, DataType::Int), maybe_value_(value) {} - - // SFINAE constructor to avoid 0 constant pointer ambiguity - template < - typename T, - typename = typename std::enable_if< - std::is_pointer::value && - std::is_convertible::value>::type> - explicit Int(Passkey passkey, T node) - : Val(passkey, DataType::Int), maybe_value_(node->value()) { - setName(node->name()); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - bool isConst() const override { - return maybe_value_.has_value(); - } - - bool isZeroInt() const override { - return maybe_value_.has_value() && *maybe_value_ == 0; - } - - bool isOneInt() const override { - return maybe_value_.has_value() && *maybe_value_ == 1; - } - - c10::optional value() const { - return maybe_value_; - } - - private: - const c10::optional maybe_value_; -}; - -class TORCH_CUDA_CU_API IterDomain final : public Val { - public: - IterDomain(Passkey passkey, Val* start, Val* extent); - - explicit IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isReduction() const { - return iterType() == IterType::Reduction; - } - - bool isRFactorProduct() const { - return is_rfactor_domain_; - } - - bool isBroadcast() const { - return iterType() == IterType::BroadcastWithStride || - iterType() == IterType::BroadcastWithoutStride; - } - - bool isGather() const { - return iterType() == IterType::Gather; - } - - bool isStride() const { - return iterType() == IterType::Stride; - } - - bool isParallelized() const { - return parallelType() != ParallelType::Serial; - } - - // Return if this iter domain is mapped to a grid dimension - bool isBlockDim() const { - return parallelType() == ParallelType::BIDz || - parallelType() == ParallelType::BIDy || - parallelType() == ParallelType::BIDx; - } - - // Return if this iter domain is mapped to a block dimension - bool isThreadDim() const { - return parallelType() == ParallelType::TIDz || - parallelType() == ParallelType::TIDy || - parallelType() == ParallelType::TIDx; - } - - // Return if this iter domain is either mapped to a block or grid dimension - bool isThread() const { - return isBlockDim() || isThreadDim(); - } - - ParallelType parallelType() const { - return parallel_type_; - } - - IterType iterType() const { - return iter_type_; - } - - Val* start() const { - return start_; - } - - Val* stop() const { - return stop_; - } - - Val* extent() const; - - bool isSimple() const { - return is_simple_; - } - - bool hasPaddingToMultipleOfWarp() const { - return is_padded_dimension_; - } - - private: - Val* const start_ = nullptr; - Val* const stop_ = nullptr; - Val* const extent_ = nullptr; - ParallelType parallel_type_ = ParallelType::Serial; - IterType iter_type_ = IterType::Iteration; - bool is_rfactor_domain_ = false; - - // An IterDomain is "simple" if the original Fusion IterDomain - // doesn't have a definition ("definition" expression) - // - // TODO(kir): this feels like a hack, revisit - // - bool is_simple_ = true; - - //! Indicates if this iterdomain is a padded parallel dimension - bool is_padded_dimension_ = false; -}; - -// TODO(kir): is this really a value? 
-class TORCH_CUDA_CU_API TensorDomain final : public Val { - public: - explicit TensorDomain(Passkey, std::vector domain); - - explicit TensorDomain( - Passkey passkey, - const fuser::cuda::TensorDomain* tensor_domain); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - std::vector::size_type nDims() const { - return domain_.size(); - } - - // TODO(kir): rename this - const std::vector& domain() const { - return domain_; - } - - const std::vector& contiguity() const { - return contiguity_; - } - - std::string getContiguityString() const { - std::stringstream ss; - for (auto b : contiguity()) { - ss << (b ? "t" : "f"); - } - return ss.str(); - } - - bool hasReduction() const; - bool hasBlockReduction() const; - bool hasGridReduction() const; - bool hasBlockBroadcast() const; - bool hasGridBroadcast() const; - bool hasBroadcast() const; - bool hasRFactor() const; - bool hasVectorize() const; - - const std::vector& noReductions() const { - return no_reduction_domain_; - } - - const std::vector& noBroadcasts() const { - return no_bcast_domain_; - } - - const std::vector& rootDomain() const { - return root_domain_; - }; - - const std::vector& rfactorDomain() const { - return rfactor_domain_; - }; - - void resetDomains() { - no_reduction_domain_ = noReductions(domain_); - no_bcast_domain_ = noBroadcasts(domain_); - } - - IterDomain* axis(int i) const; - - // TODO(kir): overloading non-static and static methods is not a good idea - static std::vector noReductions(const std::vector&); - static std::vector noBroadcasts(const std::vector&); - - private: - std::vector root_domain_; - std::vector domain_; - std::vector no_bcast_domain_; - std::vector no_reduction_domain_; - std::vector rfactor_domain_; - const std::vector contiguity_; -}; - -class TORCH_CUDA_CU_API TensorView final : public Val { - public: - explicit TensorView(Passkey, const fuser::cuda::TensorView* tv); - - TensorView( - Passkey, - DataType dtype, - TensorDomain* domain, - MemoryType memory_type); - - TensorDomain* domain() const { - return domain_; - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - MemoryType memoryType() const { - return memory_type_; - } - - fuser::cuda::TensorView* fuserTv() const { - TORCH_INTERNAL_ASSERT(fuser_tv_ != nullptr); - // TODO(kir): remove the need for const_cast - return const_cast(fuser_tv_); // NOLINT - } - - private: - TensorDomain* domain_ = nullptr; - MemoryType memory_type_ = MemoryType::Local; - - // TODO(kir): remove temporary hack - const fuser::cuda::TensorView* fuser_tv_ = nullptr; -}; - -class TORCH_CUDA_CU_API UnaryOp final : public Expr { - public: - UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in() const { - return in_; - } - - UnaryOpType operation() const { - return operation_; - } - - private: - const UnaryOpType operation_; - Val* const out_ = nullptr; - Val* const in_ = nullptr; -}; - -class TORCH_CUDA_CU_API BinaryOp final : public Expr { - public: - BinaryOp( - Passkey passkey, - BinaryOpType operation, - Val* out, - Val* lhs, - Val* rhs); - - void accept(IrVisitor* visitor) const override { - 
visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* lhs() const { - return lhs_; - } - - Val* rhs() const { - return rhs_; - } - - BinaryOpType operation() const { - return operation_; - } - - private: - const BinaryOpType operation_; - Val* const out_ = nullptr; - Val* const lhs_ = nullptr; - Val* const rhs_ = nullptr; -}; - -class TORCH_CUDA_CU_API TernaryOp final : public Expr { - public: - TernaryOp( - Passkey passkey, - TernaryOpType operation, - Val* out, - Val* in1, - Val* in2, - Val* in3); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in1() const { - return in1_; - } - - Val* in2() const { - return in2_; - } - - Val* in3() const { - return in3_; - } - - TernaryOpType operation() const { - return operation_; - } - - private: - const TernaryOpType operation_; - Val* const out_ = nullptr; - Val* const in1_ = nullptr; - Val* const in2_ = nullptr; - Val* const in3_ = nullptr; -}; - -class TORCH_CUDA_CU_API ReductionOp final : public Expr { - public: - ReductionOp( - Passkey passkey, - BinaryOpType operation, - Val* init, - Val* out, - Val* in); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in() const { - return in_; - } - - Val* init() const { - return init_; - } - - BinaryOpType operation() const { - return operation_; - } - - private: - const BinaryOpType operation_; - Val* const init_ = nullptr; - Val* const out_ = nullptr; - Val* const in_ = nullptr; -}; - -class TORCH_CUDA_CU_API WelfordOp final : public Expr { - public: - WelfordOp( - Passkey passkey, - Val* out_var, - Val* out_avg, - Val* out_N, - Val* init_var, - Val* init_avg, - Val* init_N, - Val* in_var, - Val* in_avg, - Val* in_N); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_avg_; - } - - Val* in() const { - return in_avg_; - } - - // Welford Specific accessors - // Almost wanted to add a new struct for {var, avg, N} - Val* outVar() const { - return out_var_; - } - - Val* outAvg() const { - return out_avg_; - } - - Val* outN() const { - return out_N_; - } - - Val* initVar() const { - return init_var_; - } - - Val* initAvg() const { - return init_avg_; - } - - Val* initN() const { - return init_N_; - } - - Val* inVar() const { - return in_var_; - } - - Val* inAvg() const { - return in_avg_; - } - - Val* inN() const { - return in_N_; - } - - private: - Val* const out_var_; - Val* const out_avg_; - Val* const out_N_; - Val* const init_var_; - Val* const init_avg_; - Val* const init_N_; - Val* const in_var_; - Val* const in_avg_; - Val* const in_N_; -}; - class TORCH_CUDA_CU_API TensorIndex final : public Val { public: TensorIndex( - Passkey, + IrBuilderPasskey, const fuser::cuda::TensorView* view, std::vector indices); - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - std::vector::size_type nDims() const { return indices_.size(); } @@ -1170,8 +158,7 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { TensorView* view() 
const { TORCH_INTERNAL_ASSERT(view_ != nullptr); - // TODO(kir): remove the need for const_cast - return const_cast(view_); // NOLINT + return const_cast(view_); // NOLINT } private: @@ -1179,46 +166,17 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { std::vector indices_; }; -class TORCH_CUDA_CU_API BroadcastOp final : public Expr { - public: - BroadcastOp(Passkey passkey, Val* out, Val* in); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in() const { - return in_; - } - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; -}; - //! Allocate is a lower level Node that describes a buffer of memory that //! is required as an intermediate within a kernel. The extent is the expression //! of the size of the buffer that is generated from the TensorView that //! describes the output of an operation. -//! -//! TODO(kir): The components of Allocate like Type and Name could be separated -//! from the the assocated TensorView. Perhaps that is more appropriate? -//! class TORCH_CUDA_CU_API Allocate final : public Expr { public: //! Allocation of a multi-dimensional buffer //! //! param shape Size of each dimension explicit Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, std::vector shape = {}, @@ -1228,20 +186,12 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { //! //! param size Size of allocation explicit Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, Val* size, bool zero_init = false); - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - Val* buffer() const { return buffer_; } @@ -1292,15 +242,7 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { // class TORCH_CUDA_CU_API Sync final : public Expr { public: - explicit Sync(Passkey passkey, bool war_sync = false); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit Sync(IrBuilderPasskey passkey, bool war_sync = false); bool isWarHazardSync() const { return war_sync_; @@ -1315,30 +257,14 @@ class TORCH_CUDA_CU_API Sync final : public Expr { // in helpers.cu class TORCH_CUDA_CU_API InitMagicZero final : public Expr { public: - explicit InitMagicZero(Passkey passkey); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit InitMagicZero(IrBuilderPasskey passkey); }; // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero // in helpers.cu class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { public: - explicit UpdateMagicZero(Passkey passkey); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit UpdateMagicZero(IrBuilderPasskey passkey); }; // TODO(kir): promote to IR node @@ -1377,7 +303,6 @@ class TORCH_CUDA_CU_API Scope { void push_back(Expr* e) { exprs_.push_back(e); - e->setScope(this); } // Erase expr at pos @@ -1425,7 +350,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! //! TODO: cleaner way to set options? 
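// Illustrative sketch (not part of this diff): with the IrBuilderPasskey
// constructors introduced above, lowering code no longer constructs kernel IR
// nodes with `new`; nodes such as Sync or Allocate come from the IrBuilder
// factory, which supplies the passkey and registers the node with the active
// kernel. The IrBuilder::create<T> spelling is assumed from its use elsewhere
// in this change, and appendWarSync is a hypothetical helper.
void appendWarSync(kir::ForLoop* loop) {
  // Build a WAR sync and append it to the loop's body scope.
  auto sync = IrBuilder::create<kir::Sync>(/*war_sync=*/true);
  loop->body().push_back(sync);
}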
ForLoop( - Passkey passkey, + IrBuilderPasskey passkey, IterDomain* iter_domain, Val* index, Val* start, @@ -1435,17 +360,9 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { Val* vectorize_shift, bool unroll_required); - ForLoop(Passkey passkey, IterDomain* iter_domain); - - ForLoop(Passkey passkey, const ForLoop* other); + ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain); - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + ForLoop(IrBuilderPasskey passkey, const ForLoop* other); Val* index() const { return index_; @@ -1465,6 +382,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { return iter_domain_; } + // TODO: Return pointer instead of reference to be more consistent Scope& body() { return body_; } @@ -1524,15 +442,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! class TORCH_CUDA_CU_API IfThenElse final : public Expr { public: - explicit IfThenElse(Passkey passkey, Predicate* cond); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); Scope& thenBody() { return then_body_; @@ -1567,16 +477,8 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr { //! reduction and sync buffers. class TORCH_CUDA_CU_API GridReduction final : public Expr { public: - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - GridReduction( - Passkey passkey, + IrBuilderPasskey passkey, ReductionOp* reduction_op, Allocate* reduction_buffer, Allocate* sync_buffer); @@ -1620,23 +522,11 @@ class TORCH_CUDA_CU_API GridReduction final : public Expr { //! broadcast and sync buffers. class TORCH_CUDA_CU_API GridBroadcast final : public Expr { public: - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - GridBroadcast( - Passkey passkey, + IrBuilderPasskey passkey, BroadcastOp* broadcast_op, Allocate* broadcast_buffer, - Allocate* sync_buffer) - : Expr(passkey), - broadcast_op_(broadcast_op), - broadcast_buffer_(broadcast_buffer), - sync_buffer_(sync_buffer){}; + Allocate* sync_buffer); BroadcastOp* broadcast_op() const { return broadcast_op_; @@ -1665,16 +555,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { //! reduction and sync buffers. class TORCH_CUDA_CU_API GridWelford final : public Expr { public: - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - GridWelford( - Passkey passkey, + IrBuilderPasskey passkey, WelfordOp* welford_op, Allocate* var_buffer, Allocate* avg_buffer, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h b/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h deleted file mode 100644 index 17a095baf12..00000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { -namespace kir { - -//! Kernel IR builder interface -//! -//! The only way to create new Kernel IR nodes is through the -//! kir::IrBuilder interface. An IrBuilder instance is attached to a -//! 
particular Kernel instance and it provides methods for creating -//! single nodes (kir::IrBuilder::create()) or basic composite expressions -//! (ex. kir::IrBuilder::addExpr()). -//! -//! If the Kernel object is readily available, an IrBuilder can be "wrapped" -//! around it directly: -//! -//! kir::IrBuilder ir_builder(kernel); -//! -//! During lowering, another option is to create an IrBuilder for the -//! kernel that is being created: -//! -//! kir::IrBuilder ir_builder(GpuLower::current()->kernel()); -//! -//! Once we have an IR builder instance, creating nodes looks like: -//! -//! auto new_node = ir_builder.create(1)); -//! auto result = ir_builder.mulExpr(lhs, rhs); -//! -class TORCH_CUDA_CU_API IrBuilder { - public: - explicit IrBuilder(Kernel* kernel) : kernel_(kernel) {} - - //! Allocate a new Kernel IR node, forwarding the arguments - //! to the appropriate constructor - template - T* create(Args&&... args) { - const kir::Passkey passkey(kernel_); - const auto node = new T(passkey, std::forward(args)...); - kernel_->registerIrNode(passkey, std::unique_ptr(node)); - return node; - } - - // Unary operations - Val* negExpr(Val* val); - Val* notExpr(Val* val); - Val* setExpr(Val* val); - Val* setExprNamedScalar(const std::string& name, Val* val); - Val* addressExprNamedScalar(const std::string& name, Val* val); - - // Binary operations - Val* andExpr(Val* lhs, Val* rhs); - Val* eqExpr(Val* lhs, Val* rhs); - Val* gtExpr(Val* lhs, Val* rhs); - Val* ltExpr(Val* lhs, Val* rhs); - Val* leExpr(Val* lhs, Val* rhs); - Val* geExpr(Val* lhs, Val* rhs); - Val* addExpr(Val* lhs, Val* rhs); - Val* subExpr(Val* lhs, Val* rhs); - Val* mulExpr(Val* lhs, Val* rhs); - Val* divExpr(Val* lhs, Val* rhs); - Val* ceilDivExpr(Val* lhs, Val* rhs); - Val* modExpr(Val* lhs, Val* rhs); - Val* maxExpr(Val* lhs, Val* rhs); - Val* minExpr(Val* lhs, Val* rhs); - - // Ternary operations - Val* whereExpr(Val* pred, Val* lhs, Val* rhs); - - // Shortcuts for frequently used vals - Int* zeroVal(); - Int* oneVal(); - Bool* falseVal(); - Bool* trueVal(); - - NamedScalar* magicZeroVal(); - - private: - Val* newResult(DataType dtype); - Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs); - Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs); - - private: - // Non-owning pointer to the kernel to be modified - Kernel* kernel_ = nullptr; - // Frequently used constant vals - Int* zero_ = nullptr; - Int* one_ = nullptr; - Bool* false_ = nullptr; - Bool* true_ = nullptr; - - // Magic zero corresponds to runtime/helpers.cu magic_zero - NamedScalar* magic_zero_ = nullptr; -}; - -//! A wrapper builder with static expression simplification -//! -//! Example: -//! - addExpr(new Int(1), new Int(2)) -> Int(3) -//! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo") -//! -//! Designed to be used to simplify predicate and index expressions in -//! generated code. Also, the shift validation may fail without -//! this simplification. 
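// Illustrative sketch (not part of this diff): the simplification described
// above is ordinary constant folding plus dropping of the additive identity.
// Roughly what a simplifying addExpr is expected to do, using only the
// Int/IrBuilder interfaces shown in this change; simplifiedAdd itself is a
// hypothetical stand-in:
Val* simplifiedAdd(IrBuilder& builder, Int* lhs, Int* rhs) {
  if (lhs->isConst() && rhs->isConst()) {
    // Both operands known at compile time: fold, e.g. 1 + 2 -> 3.
    return builder.create<Int>(*lhs->value() + *rhs->value());
  }
  if (lhs->isZeroInt()) {
    return rhs; // 0 + x -> x
  }
  if (rhs->isZeroInt()) {
    return lhs; // x + 0 -> x
  }
  // No folding possible: build the regular add expression.
  return builder.addExpr(lhs, rhs);
}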
-class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { - public: - explicit SimplifyingIrBuilder(Kernel* kernel) : IrBuilder(kernel) {} - - Val* negExpr(Val* val); - Val* notExpr(Val* val); - - Val* addExpr(Int* lhs, Int::ScalarType rhs); - Val* addExpr(Int* lhs, Int* rhs); - Val* addExpr(Val* lhs, Val* rhs); - Val* subExpr(Val* lhs, Val* rhs); - Val* andExpr(Val* lhs, Val* rhs); -}; - -} // namespace kir -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp new file mode 100644 index 00000000000..bfc4794e299 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp @@ -0,0 +1,180 @@ +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace kir { +std::vector IrVisitor::handle(const std::vector& exprs) { + exprs_ = std::vector(exprs); + for (auto expr : exprs) { + handle(expr); + } + return exprs_; +} + +void IrVisitor::handle(ForLoop* fl) { + for_loops_.push_back(fl); + scope_.push_back(&fl->body()); + auto body_exprs = std::vector(fl->body().exprs()); + for (auto expr : body_exprs) { + handle(expr); + } + scope_.pop_back(); + for_loops_.pop_back(); +} + +void IrVisitor::handle(IfThenElse* ite) { + scope_.push_back(&ite->thenBody()); + auto then_exprs = std::vector(ite->thenBody().exprs()); + for (auto expr : then_exprs) { + handle(expr); + } + scope_.pop_back(); + + scope_.push_back(&ite->elseBody()); + auto else_exprs = std::vector(ite->elseBody().exprs()); + for (auto expr : else_exprs) { + handle(expr); + } + scope_.pop_back(); +} + +std::vector ExprMutator::mutate(bool reverse_order) { + if (insertions_.empty() && replacements_.empty()) { + return exprs_; + } + + auto run_insertion = [&](MutationInformation info) { + if (info.scope == nullptr) { + // If reference is nullptr and there are no expressions, simply insert the + // expr + if (exprs_.empty() && info.reference == nullptr) { + exprs_.push_back(info.new_expr); + return; + } + auto pos_it = std::find(exprs_.begin(), exprs_.end(), info.reference); + TORCH_INTERNAL_ASSERT( + pos_it != exprs_.end(), + "Issue finding reference expression for insertion."); + if (info.mode == MutationMode::BEFORE) { + exprs_.insert(pos_it, info.new_expr); + } else { + exprs_.insert(pos_it + 1, info.new_expr); + } + } else { + // If reference is nullptr and there are no expressions, simply insert the + // expr + if (info.scope->exprs().empty() && info.reference == nullptr) { + info.scope->push_back(info.new_expr); + return; + } + if (info.mode == MutationMode::BEFORE) { + info.scope->insert_before(info.reference, info.new_expr); + } else { + info.scope->insert_after(info.reference, info.new_expr); + } + } + }; + + if (reverse_order) { + for (auto it = insertions_.rbegin(); it != insertions_.rend(); ++it) { + run_insertion(*it); + } + } else { + for (auto insertion_info : insertions_) { + run_insertion(insertion_info); + } + } + + for (auto replacement_info : replacements_) { + if (replacement_info.scope == nullptr) { + auto pos_it = + std::find(exprs_.begin(), exprs_.end(), replacement_info.reference); + TORCH_INTERNAL_ASSERT( + pos_it != exprs_.end(), + "Issue finding reference expression for replacement."); + exprs_.insert(pos_it, replacement_info.new_expr); + // iterator can be invalidated from insertion + pos_it = + std::find(exprs_.begin(), exprs_.end(), replacement_info.reference); + exprs_.erase(pos_it); + } else { + 
replacement_info.scope->insert_before( + replacement_info.reference, replacement_info.new_expr); + replacement_info.scope->erase(replacement_info.reference); + } + } + + insertions_.clear(); + replacements_.clear(); + + return exprs_; +} + +std::vector ExprMutator::traverseAndInsert( + const std::vector& exprs, + bool reverse_order) { + IrVisitor::handle(exprs); + return mutate(reverse_order); +} + +void ExprMutator::registerMutation( + Expr* reference, + Expr* new_expr, + Scope* scope, + MutationMode mode) { + MutationInformation mutation; + mutation.reference = reference; + mutation.new_expr = new_expr; + mutation.scope = scope; + mutation.mode = mode; + if (mode == MutationMode::BEFORE || mode == MutationMode::AFTER) { + insertions_.push_back(mutation); + } else { + replacements_.push_back(mutation); + } +} + +void ExprMutator::registerInsertBefore( + Expr* reference, + Expr* new_expr, + Scope* scope) { + registerMutation(reference, new_expr, scope, MutationMode::BEFORE); +} + +void ExprMutator::registerInsertAfter( + Expr* reference, + Expr* new_expr, + Scope* scope) { + registerMutation(reference, new_expr, scope, MutationMode::AFTER); +} + +void ExprMutator::registerReplace( + Expr* reference, + Expr* new_expr, + Scope* scope) { + registerMutation(reference, new_expr, scope, MutationMode::REPLACE); +} + +void ExprMutator::registerInsertBefore(Expr* reference, Expr* new_expr) { + Scope* scope = scope_.empty() ? nullptr : scope_.back(); + registerInsertBefore(reference, new_expr, scope); +} + +void ExprMutator::registerInsertAfter(Expr* reference, Expr* new_expr) { + Scope* scope = scope_.empty() ? nullptr : scope_.back(); + registerInsertAfter(reference, new_expr, scope); +} + +void ExprMutator::registerReplace(Expr* reference, Expr* new_expr) { + Scope* scope = scope_.empty() ? nullptr : scope_.back(); + registerReplace(reference, new_expr, scope); +} + +} // namespace kir +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h new file mode 100644 index 00000000000..2140498af14 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h @@ -0,0 +1,118 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class Expr; + +namespace kir { +class Predicate; +class TensorIndex; +class ForLoop; +class IfThenElse; +class Scope; + +// Base visitor class that visits all nodes in provided vector. +// +// Includes visiting through scopes like IfThenElse and ForLoop, and tracks +// them in scopes_ and for_loops_. +// +// Makes a copy of exprs at exprs_ which could be used to modify and return. +// +// When traversing through ITE/FLs it will use a copy +// of the provided expressions to make it safe to insert/delete nodes. +// +// Provides a simple base class to inherit from for typical lowering passes on +// Expr list +class TORCH_CUDA_CU_API IrVisitor : public OptOutDispatch { + public: + std::vector handle(const std::vector& expr); + + protected: + using OptOutDispatch::handle; + + virtual void handle(ForLoop*) override; + virtual void handle(IfThenElse*) override; + + protected: + std::vector for_loops_; + std::vector scope_; + std::vector exprs_; +}; + +// Base Expr Mutator class that visits all nodes with IrVisitor, and then +// inserts new expressions or replaces expressions based on insertion/replace +// maps provided. 
These replacement maps are expected to accumulate during an +// initial traversal, then runs an insertion based on them after the overloaded +// traversal. +// +// Order of mutations may be important, mutations are ordered according to the +// following rules: +// Before/After insertions are ordered as registered when reverse_order == +// false, +// +// Before/After insertions are in reverse order as registered when +// reverse_order == true, +// +// Before/After insertions are done before Expr replacements, so reference for +// insertions must be on pre-replaced Exprs +// +// To place in a scope that is empty, simply provide a nullptr reference +// Since insertions are done in order, it's possible to insert an expression in +// an empty scope, and then use that inserted scope as a reference for +// subsequent mutations. +class ExprMutator : public IrVisitor { + protected: + std::vector traverseAndInsert( + const std::vector& expr, + bool reverse_order = false); + + std::vector mutate(bool reverse_order = false); + + using IrVisitor::handle; + // Registration function which *don't* need to be called "in place" during + // visiting. + void registerInsertBefore(Expr* reference, Expr* new_expr, Scope* scope); + void registerInsertAfter(Expr* reference, Expr* new_expr, Scope* scope); + void registerReplace(Expr* reference, Expr* new_expr, Scope* scope); + + // Registration function which need to be called "in place" during visiting. + // I.E. + // if you want to insert before/after or replace an Expr, you must register + // when in handle(Expr*) of that expr. + void registerInsertBefore(Expr* reference, Expr* new_expr); + void registerInsertAfter(Expr* reference, Expr* new_expr); + void registerReplace(Expr* reference, Expr* new_expr); + + private: + enum class MutationMode { BEFORE, AFTER, REPLACE }; + + void registerMutation( + Expr* ref, + Expr* new_expr, + Scope* scope, + MutationMode mode); + + struct MutationInformation { + Expr* reference = nullptr; + Expr* new_expr = nullptr; + Scope* scope = nullptr; + MutationMode mode = MutationMode::BEFORE; + }; + + // Track insertions as they're registered + std::vector insertions_; + + // Track replacements as they're registered + std::vector replacements_; +}; + +} // namespace kir +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp deleted file mode 100644 index e00da31423c..00000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp +++ /dev/null @@ -1,451 +0,0 @@ -#include -#include - -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { -namespace kir { - -namespace { - -const char* boolLiteral(bool value) { - return value ? 
"true" : "false"; -} - -std::string varName(const kir::Val* val, const char* prefix) { - std::stringstream value_name; - if (val == nullptr) { - value_name << "$nullptr"; - } else if (val->name() != kInvalidStmName) { - value_name << prefix << val->name(); - } else { - value_name << "k" << prefix << val->id(); - } - return value_name.str(); -} - -} // namespace - -void IrPrinter::printNode(const kir::Node* node) { - os_ << gen(node, true); -} - -void IrPrinter::printKernel(const Kernel* kernel) { - TORCH_CHECK(kernel != nullptr); - - // kernel declaration - os_ << "\nKERNEL ("; - for (auto in : kernel->inputs()) { - os_ << gen(in); - if (in != kernel->inputs().back()) { - os_ << ", "; - } - } - os_ << ") -> ("; - for (auto out : kernel->outputs()) { - os_ << gen(out); - if (out != kernel->outputs().back()) { - os_ << ", "; - } - } - os_ << ") :\n"; - - // kernel body - startBlock(); - for (auto expr : kernel->topLevelExprs()) { - os_ << gen(expr, true); - } - endBlock(); - os_ << "END.\n\n"; -} - -std::ostream& IrPrinter::indent() { - for (const auto i : c10::irange(indent_level_)) { - (void)i; // Suppress unused variable warning - ir_str_ << kTab; - } - ir_str_ << margin_; - return ir_str_; -} - -std::string IrPrinter::gen(const kir::Node* node, bool top_level) { - if (node == nullptr) { - return "$nullptr"; - } - - // If we're generatign a top level statement we expect to start - // with an empty set of uses - TORCH_INTERNAL_ASSERT(!implicit_definition_ || uses_.empty() || !top_level); - - // Mark the node as generated - visited_.insert(node); - - // Generate the node itself - std::stringstream node_str; - std::swap(node_str, ir_str_); - node->accept(this); - std::swap(node_str, ir_str_); - - if (!implicit_definition_) { - return node_str.str(); - } - - if (top_level) { - // Implicitly mark top level nodes as used, so we - // get their definitions printed (useful for debugging) - if (auto val = dynamic_cast(node)) { - uses_.insert(val); - } - - // Make a copy of the node uses (and reset global state) - const auto node_uses = uses_; - uses_.clear(); - - std::stringstream top_level_str; - - // Hoist implicit definitions - for (auto use : node_uses) { - const auto def = use->definition(); - if (def && visited_.find(def) == visited_.end()) { - margin_ = "~ "; - top_level_str << gen(def, true); - margin_ = ""; - } - } - - top_level_str << node_str.str(); - return top_level_str.str(); - } else { - return node_str.str(); - } -} - -std::string IrPrinter::use(const kir::Val* val) { - if (val != nullptr) { - uses_.insert(val); - } - return gen(val); -} - -void IrPrinter::startBlock() { - ++indent_level_; -} - -void IrPrinter::endBlock() { - TORCH_CHECK(indent_level_ > 0); - --indent_level_; -} - -void IrPrinter::handleBlock(const kir::Scope& scope) { - // Save the uses of the parent scope - decltype(uses_) outer_uses; - std::swap(uses_, outer_uses); - - startBlock(); - for (auto expr : scope.exprs()) { - ir_str_ << gen(expr, true); - } - endBlock(); - - // Restore parent's uses - std::swap(uses_, outer_uses); -} - -void IrPrinter::visit(const kir::Bool* node) { - if (node->isConst()) { - ir_str_ << boolLiteral(*node->value()); - } else { - ir_str_ << varName(node, "b"); - } -} - -void IrPrinter::visit(const kir::Double* node) { - if (node->isConst()) { - const int digits = std::numeric_limits::max_digits10; - ir_str_ << "double(" << std::setprecision(digits) << *node->value() << ")"; - } else { - ir_str_ << varName(node, "d"); - } -} - -void IrPrinter::visit(const kir::Int* node) { - if 
(node->isConst()) { - ir_str_ << *node->value(); - } else { - ir_str_ << varName(node, "i"); - } -} - -void IrPrinter::visit(const kir::NamedScalar* node) { - ir_str_ << node->name(); -} - -void IrPrinter::visit(const kir::Predicate* node) { - switch (node->predicate_type()) { - case PredicateType::Inline: { - ir_str_ << "Inline"; - break; - } - case PredicateType::Manual: { - ir_str_ << node->value(); - break; - } - case PredicateType::Misaligned: { - ir_str_ << "Misaligned"; - break; - } - case PredicateType::Padding: { - ir_str_ << "Padding"; - break; - } - case PredicateType::Shift: { - ir_str_ << "Shift"; - break; - } - case PredicateType::Unswitch: { - ir_str_ << "Unswitch"; - break; - } - case PredicateType::Vectorize: { - ir_str_ << "Vectorize"; - break; - } - default: - break; - } -} - -void IrPrinter::visit(const kir::TensorIndex* node) { - ir_str_ << gen(node->view()) << "["; - for (auto index : node->indices()) { - ir_str_ << use(index); - if (index != node->indices().back()) { - ir_str_ << ", "; - } - } - ir_str_ << "]"; -} - -void IrPrinter::visit(const kir::IterDomain* node) { - ir_str_ << varName(node, "id") << "["; - if (node->isRFactorProduct()) { - ir_str_ << "rfactor."; - } - ir_str_ << node->parallelType() << "." << node->iterType() << "(" - << use(node->start()) << " .. " << use(node->extent()) << ")]"; -} - -void IrPrinter::visit(const kir::TensorDomain*) { - // TODO(kir): print Tensor shapes? - ir_str_ << "kir::TensorDomain"; -} - -void IrPrinter::visit(const kir::TensorView* node) { - // TODO(kir): print memory type too? - ir_str_ << varName(node, "T"); -} - -void IrPrinter::visit(const kir::UnaryOp* node) { - indent() << gen(node->out()) << " = "; - - auto op_type = node->operation(); - - if (auto op = inline_op_str(op_type)) { - if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { - ir_str_ << stringifyBooleanOp(op_type) << gen(node->in()); - } else { - ir_str_ << *op << gen(node->in()); - } - } else { - if (op_type == UnaryOpType::Cast) { - const auto cast_str = - cast_func_str({node->in()->dtype(), node->out()->dtype()}); - ir_str_ << cast_str.value(); - } else { - ir_str_ << op_type; - if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) { - ir_str_ << "f"; - } - } - - if (op_type == UnaryOpType::RandLike) { - ir_str_ << "(RND"; - } else { - ir_str_ << "("; - ir_str_ << use(node->in()); - } - ir_str_ << ")"; - } - - ir_str_ << "\n"; -} - -void IrPrinter::visit(const kir::BinaryOp* node) { - indent() << gen(node->out()) << " = "; - - const auto op_type = node->operation(); - const auto lhs = use(node->lhs()); - const auto rhs = use(node->rhs()); - - if (auto op = inline_op_str(op_type)) { - ir_str_ << lhs << " "; - if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { - ir_str_ << stringifyBooleanOp(op_type); - } else { - ir_str_ << *op; - } - ir_str_ << " " << rhs; - } else { - ir_str_ << op_type; - if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) { - ir_str_ << "f"; - } - ir_str_ << "(" << lhs << ", " << rhs << ")"; - } - - ir_str_ << "\n"; -} - -void IrPrinter::visit(const kir::TernaryOp* node) { - indent() << gen(node->out()) << " = " << node->operation() << "(" - << use(node->in1()) << ", " << use(node->in2()) << ", " - << use(node->in3()) << ")\n"; -} - -void IrPrinter::visit(const kir::ReductionOp* node) { - indent() << gen(node->out()) << " = " - << "REDUCTION(op='" << node->operation() << "'" - << ", in=" << use(node->in()) << ", init=" << use(node->init()) 
- << ", pred=" << use(node->predicate()) << ")\n"; -} - -void IrPrinter::visit(const kir::WelfordOp* node) { - indent() << gen(node->outVar()) << "," << gen(node->outAvg()) << "," - << gen(node->outN()) << " = " - << "Welford( inAvg=" << use(node->inAvg()); - if (!node->inN()->isOneInt()) { - indent() << " inVar=" << use(node->inVar()); - } - indent() << " inN=" << use(node->inN()); - if (!node->initN()->isZeroInt()) { - indent() << ", initVar=" << use(node->initVar()) - << " initAvg=" << use(node->initAvg()) - << " initN=" << use(node->initN()); - } - indent() << ", pred=" << use(node->predicate()) << ")\n"; -} - -void IrPrinter::visit(const kir::GridReduction* node) { - const auto* reduction_op = node->reduction_op(); - indent() << gen(reduction_op->out()) << " = " - << "GRID_REDUCTION(op='" << reduction_op->operation() << "'" - << ", in=" << use(reduction_op->in()) - << ", init=" << use(reduction_op->init()) - << ", pred=" << use(reduction_op->predicate()) << ")\n"; - indent() << kTab << kTab - << ".reduction_buffer=" << use(node->reduction_buffer()->buffer()) - << "\n"; - indent() << kTab << kTab - << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n"; - indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n"; -} - -void IrPrinter::visit(const kir::GridWelford* node) { - const auto* welford_op = node->welford_op(); - indent() << gen(welford_op->outVar()) << "," << gen(welford_op->outAvg()) - << "," << gen(welford_op->outN()) << " = " - << "GRID_WELFORD(" - << "inAvg=" << use(welford_op->inAvg()); - if (!welford_op->inN()->isOneInt()) { - indent() << ", inVar=" << use(welford_op->inVar()); - } - indent() << ", inN=" << use(welford_op->inN()); - if (!welford_op->initN()->isZeroInt()) { - indent() << ", initVar=" << use(welford_op->initVar()) - << " initAvg=" << use(welford_op->initAvg()) - << " initN=" << use(welford_op->initN()); - } - indent() << ", pred=" << use(welford_op->predicate()) << ")\n"; - indent() << kTab << kTab - << ".var_buffer=" << use(node->var_buffer()->buffer()) - << ".avg_buffer=" << use(node->avg_buffer()->buffer()) - << ".n_buffer=" << use(node->N_buffer()->buffer()) << "\n"; - indent() << kTab << kTab - << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n"; - indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n"; -} - -void IrPrinter::visit(const kir::BroadcastOp* node) { - indent() << gen(node->out()) << " = BROADCAST(" << use(node->in()) << ")\n"; -} - -void IrPrinter::visit(const kir::ForLoop* node) { - indent() << "FOR " << gen(node->index()) << " in " << gen(node->iter_domain()) - << ":\n"; - handleBlock(node->body()); -} - -void IrPrinter::visit(const kir::IfThenElse* node) { - indent() << "IF " << use(node->predicate()) << ":\n"; - handleBlock(node->thenBody()); - if (node->hasElse()) { - indent() << "ELSE:\n"; - handleBlock(node->elseBody()); - } -} - -void IrPrinter::visit(const kir::Allocate* node) { - indent() << gen(node->buffer()) << " = ALLOCATE(" - << "mem_type=" << node->memoryType() << ", " - << "size=" << use(node->size()) << ", " - << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n"; - if (node->alias() != nullptr) { - indent() << kTab << kTab << ".alias=" << gen(node->alias()->buffer()) - << "\n"; - } -} - -void IrPrinter::visit(const kir::Sync* node) { - indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync()) - << ")\n"; -} - -void IrPrinter::visit(const kir::InitMagicZero* node) { - indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; -} - -void IrPrinter::visit(const 
kir::UpdateMagicZero* node) { - indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; -} - -std::string toString(const kir::Node* stmt, bool implicit_definitions) { - std::stringstream ss; - IrPrinter ir_printer(ss, implicit_definitions); - ir_printer.printNode(stmt); - return ss.str(); -} - -std::string toString( - const std::vector& exprs, - bool implicit_definitions) { - std::stringstream ss; - IrPrinter ir_printer(ss, implicit_definitions); - for (auto expr : exprs) { - ir_printer.printNode(expr); - } - return ss.str(); -} - -} // namespace kir -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h deleted file mode 100644 index 115901a031a..00000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h +++ /dev/null @@ -1,129 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include -#include -#include -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { -namespace kir { - -//! Define pretty printing functions for Kernel IR nodes -//! -//! This class is intended for debug printing, so it attempts -//! to handle invalid IR states as much as possible. -//! -//! implicit_definition_ = true will recurisvely print the definition of all -//! inputs to an expression if they haven't been printed. -class TORCH_CUDA_CU_API IrPrinter : private kir::IrVisitor { - static constexpr char const* kTab = " "; - - public: - //! Constructs a new IrPrinter which outputs to the specified stream - explicit IrPrinter(std::ostream& os, bool implicit_definition = true) - : os_(os), implicit_definition_(implicit_definition) {} - - //! Print a single Kernel IR node - void printNode(const kir::Node* node); - - //! 
Print a complete Kernel definition - void printKernel(const Kernel* kernel); - - private: - // Generates a string representation of an IR node - // - // If `top_level` is true, all the value uses are tracked and - // their definitions are implicitly printed before the node itself - // - std::string gen(const kir::Node* node, bool top_level = false); - - // Generate a string representation of an used value - // (this helps automatically tracking the value uses) - std::string use(const kir::Val* val); - - std::ostream& indent(); - - void startBlock(); - void endBlock(); - void handleBlock(const kir::Scope& scope); - - void visit(const kir::Bool*) final; - void visit(const kir::Double*) final; - void visit(const kir::Int*) final; - void visit(const kir::NamedScalar*) final; - void visit(const kir::Predicate*) final; - - void visit(const kir::TensorIndex*) final; - void visit(const kir::IterDomain*) final; - void visit(const kir::TensorDomain*) final; - void visit(const kir::TensorView*) final; - - void visit(const kir::UnaryOp*) final; - void visit(const kir::BinaryOp*) final; - void visit(const kir::TernaryOp*) final; - void visit(const kir::ReductionOp*) final; - void visit(const kir::WelfordOp*) final; - void visit(const kir::BroadcastOp*) final; - - void visit(const kir::GridReduction*) final; - void visit(const kir::GridWelford*) final; - void visit(const kir::ForLoop*) final; - void visit(const kir::IfThenElse*) final; - void visit(const kir::Allocate*) final; - void visit(const kir::Sync*) final; - void visit(const kir::InitMagicZero*) final; - void visit(const kir::UpdateMagicZero*) final; - - private: - std::ostream& os_; - - // Current indentation level - int indent_level_ = 0; - - // Internal IR generation stream - std::stringstream ir_str_; - - // Tracks the set of nodes which have been printed - std::unordered_set visited_; - - // Optional left margin printed after the indentation - const char* margin_ = ""; - - // The set of values used by the current top-level IR node - std::unordered_set uses_; - - // If the definition of all inputs to an expression haven't been printed - // already implicit_definition_ = true will print them before printing the - // requested node. - bool implicit_definition_ = true; -}; - -//! Returns the string representation of a Kernel IR node. If the definition of -//! all inputs to an expression haven't been printed already -//! implicit_definition_ = true will print them before printing the requested -//! node. -TORCH_CUDA_CU_API std::string toString( - const kir::Node* stmt, - bool implicit_definitions = true); - -//! Returns the string representation of a vector of kir::Expr, convenient -//! debugm echanism during lowering. If the definition of all inputs to an -//! expression haven't been printed already implicit_definition_ = true will -//! print them before printing the requested node. 
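// Illustrative usage (not part of this diff): typical call sites of the
// debug-printing helpers declared here look roughly like
//
//   std::cout << kir::toString(expr);                     // a single node
//   std::cout << kir::toString(kernel->topLevelExprs());  // a lowered list
//
// with implicit_definitions controlling whether the definitions of values
// used by an expression are printed before the expression itself.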
-TORCH_CUDA_CU_API std::string toString( - const std::vector& exprs, - bool implicit_definitions = true); - -} // namespace kir -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 036eee58206..21eb6e02fb8 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -6,18 +6,19 @@ #include #include #include -#include #include #include +#include #include +#include #include #include #include #include #include #include +#include #include -#include #include #include #include @@ -33,152 +34,15 @@ namespace jit { namespace fuser { namespace cuda { -// TODO(kir): revisit this thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT namespace { -// Going to generate a map of tensor view root domain extents to reduce the -// number used during lowering. For example if we have: -// -// T2[i0, i1] = T1[i0, i1] + T2[i2, i3] -// -// We know it would be safe to use: -// -// T2[i0, i1] = T1[i0, i1] + T2[i0, i1] -// -// And that way we don't generate T2.size[0] and T2.size[1], instead we will -// reuse T1.size[0] and T1.size[1] -// This is important when doing CSE as T2 and T1 would otherwise look like -// they're using different values, even though we know they're the same -// -// There's some duplicate logic here that's in computeAt map, but it's not so -// concice there to pull out. May want to consider making this mapping its own -// class especially as it may be useful during scheduling. -std::unordered_map getSimplificationMap(Fusion* fusion) { - std::list> disjoint_root_sets; - std::unordered_map*> - id_to_disjoint_root_set; - - auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set]( - IterDomain* id0, IterDomain* id1) { - if (id0->isBroadcast() || id1->isBroadcast()) { - return; - } - - auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0); - auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1); - bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end(); - bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end(); - - if (set_0_found && set_1_found) { - if (disjoint_set_0_it->second == disjoint_set_1_it->second) { - return; - } - // merge second disjoint set into first - auto* set_0 = disjoint_set_0_it->second; - auto* set_1 = disjoint_set_1_it->second; - for (auto id : *set_1) { - set_0->emplace(id); - id_to_disjoint_root_set[id] = set_0; - } - // remove second set from disjoint_root_sets - disjoint_root_sets.erase(std::find( - disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1)); - } else if (set_0_found || set_1_found) { - auto existing_set = - set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second; - auto to_add_id = set_0_found ? 
id1 : id0; - existing_set->emplace(to_add_id); - id_to_disjoint_root_set[to_add_id] = existing_set; - // add entry into existing set - } else { - // create new set entry - disjoint_root_sets.emplace_back(std::unordered_set()); - auto* new_set = &disjoint_root_sets.back(); - new_set->emplace(id0); - new_set->emplace(id1); - id_to_disjoint_root_set[id0] = new_set; - id_to_disjoint_root_set[id1] = new_set; - } - }; - - auto fusion_vals = fusion->usedMathVals(); - for (auto producer_tv : ir_utils::filterByType(fusion_vals)) { - auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv); - for (auto consumer_tv : consumer_tvs) { - auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); - auto c2p_root_map = pairwise_map.mapConsumerToProducer( - consumer_tv->domain(), producer_tv->domain()); - for (auto entry : c2p_root_map) { - auto c_id = entry.first; - auto p_id = entry.second; - map_root_ids(p_id, c_id); - } - } - } - - // Map each set to an input ID (if it exists) that has the smallest ->name() - // entry value - std::unordered_map*, IterDomain*> - set_to_input_id; - - // Loop over the root domains, of the inputs to the fusion. Pick an input ID - // to use as the representative ID of the collected sets. Only consider inputs - // as those are the ones that map to values like "T0.size[1]". They are he - // ID's that propagated their extents into the problem. We could also check - // the outputs as we do have C++ examples of using output dimensions for the - // problem size instead of inputs. However, we don't do anything where we can - // translate to those kinds of kernels integrated into PyTorch. - for (auto input_tv : ir_utils::filterByType(fusion->inputs())) { - for (auto id : - TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) { - auto id_set_it = id_to_disjoint_root_set.find(id); - if (id_set_it == id_to_disjoint_root_set.end()) { - continue; - } - auto* id_set = id_set_it->second; - if (set_to_input_id.find(id_set) == set_to_input_id.end()) { - set_to_input_id[id_set] = id; - } else { - auto input_id_of_set = set_to_input_id.at(id_set); - // Swap id's if new name is less than previously set - bool swap_ids = id->name() < input_id_of_set->name(); - // If new id is a const scalar but previously was'nt use the const - // scalar - swap_ids = swap_ids || - (id->extent()->isConstScalar() && - !input_id_of_set->extent()->isConstScalar()); - // If previous scalar was const and new isn't, don't swap - swap_ids = swap_ids && - !(input_id_of_set->extent()->isConstScalar() && - !id->extent()->isConstScalar()); - - if (swap_ids) { - set_to_input_id[id_set] = id; - } - } - } - } - - // Finally make map from ID extents to the representitive ID extent. - std::unordered_map extent_to_min_input_id_extent; - for (auto entry : set_to_input_id) { - auto* set = entry.first; - auto input_id = entry.second; - for (auto id : *set) { - extent_to_min_input_id_extent[id->extent()] = input_id->extent(); - } - } - return extent_to_min_input_id_extent; -} - -class KIRCleaner : public kir::MutableIrVisitor { +class KIRCleaner : public OptOutDispatch { public: //! 
Remove nop IR nodes - static std::vector cleanUp( - const std::vector& loop_nests) { + static std::vector cleanUp(const std::vector& loop_nests) { KIRCleaner cleaner; - std::vector out_loop_nests; + std::vector out_loop_nests; for (auto loop_nest : loop_nests) { cleaner.handle(loop_nest); // No need to keep the loop nest if it's determined to be nop @@ -190,16 +54,17 @@ class KIRCleaner : public kir::MutableIrVisitor { } private: - void handle(kir::Expr* expr) { + using OptOutDispatch::handle; + void handle(Expr* expr) final { if (expr->isA() || expr->isA()) { - expr->accept(this); + OptOutDispatch::handle(expr); } else { // Any non-scoping expr is not considered nop is_nop_ = false; } } - void visit(kir::ForLoop* fl) final { + void handle(kir::ForLoop* fl) final { auto exprs = fl->body().exprs(); fl->body().clear(); for (auto expr : exprs) { @@ -213,7 +78,7 @@ class KIRCleaner : public kir::MutableIrVisitor { is_nop_ = fl->body().empty(); } - void visit(kir::IfThenElse* ite) final { + void handle(kir::IfThenElse* ite) final { const auto conditional = ite->predicate()->value(); // Visit the then block @@ -248,9 +113,8 @@ class KIRCleaner : public kir::MutableIrVisitor { // conditional and move the exprs in the else block to the then // block. if (then_nop && !else_nop) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - kir::Bool* pred = ite->predicate()->value(); - kir::Bool* not_pred = ir_builder.notExpr(pred)->as(); + Bool* pred = ite->predicate()->value(); + Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as(); ite->predicate()->setValue(not_pred); for (auto expr : ite->elseBody().exprs()) { ite->thenBody().push_back(expr); @@ -269,84 +133,6 @@ class KIRCleaner : public kir::MutableIrVisitor { } // namespace -void GpuLower::replaceSymbolicSizes() { - FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes"); - - kir::IrBuilder ir_builder(kernel()); - - // Grab inputs and outputs - std::vector inputs_and_outputs; - for (auto val : fusion_->inputs()) { - if (ir_utils::isTV(val)) { - inputs_and_outputs.push_back(val->as()); - } - } - // Symbolic size is necessary for outputs if there are no inputs. - // Otherwise infer output sizes from the inputs via expression evaluation. - if (fusion_->inputs().empty()) { - for (auto val : fusion_->outputs()) { - if (ir_utils::isTV(val)) { - inputs_and_outputs.push_back(val->as()); - } - } - } - - // Generate map for all tensorview root domain values to map them to symbolic - // values. i.e. T0->getRootDomain()[0] would map to a named scalar - // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir. - for (TensorView* tv : inputs_and_outputs) { - // Replace the domain with one based on Ti.size[j] - const std::vector& root_td = tv->getRootDomain(); - - size_t dim = 0; - for (auto id : root_td) { - const Val* orig_size = id->extent(); - - // Output sizes could have reduction axes, which isn't what gets output. 
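// Illustrative note (not part of this diff): for a fusion input such as
// T0[i0, i1] whose extents are not compile-time constants, the remainder of
// this loop binds each extent to a NamedScalar spelled "T0.size[0]" /
// "T0.size[1]", so later indexing and predicate math refers to those runtime
// sizes rather than the original symbolic extents.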
- // NOLINTNEXTLINE(bugprone-branch-clone) - if (id->isReduction() || - (id->getIterType() == IterType::BroadcastWithoutStride)) { - continue; - } else if ( - id->isRFactorProduct() || - // NOLINTNEXTLINE(bugprone-branch-clone) - (id->getIterType() == IterType::BroadcastWithStride) || - orig_size->isConstScalar()) { - dim++; - continue; - } - - // TODO(kir): consider a different implementation which doesn't - // hijack the kir_val_map_ - // Currently turn off this part for inputs of segmented fusion, - // since FusionKernelRuntime will provide these as integer inputs - if (kir_val_map_.find(orig_size) == kir_val_map_.end() && - !orig_size->isFusionInput() && !orig_size->isConstScalar()) { - std::stringstream ss; - ss << "T" << tv->name() << ".size[" << dim++ << "]"; - kir_val_map_[orig_size] = ir_builder.create( - ss.str(), orig_size->getDataType().value()); - } else { - dim++; - } - } - } - - // Use a minimal number of sizes from provided tensors. - auto extent_simplification_map = getSimplificationMap(fusion_); - for (auto extent_entry : extent_simplification_map) { - auto orig_extent = extent_entry.first; - auto simplified_extent = extent_entry.second; - if (kir_val_map_.count(orig_extent)) { - if (kir_val_map_.count(simplified_extent)) { - kir_val_map_[orig_extent] = kir_val_map_[simplified_extent]; - } else { - kir_val_map_[orig_extent] = lowerValue(simplified_extent); - } - } - } -} - void GpuLower::collectPaddedParallelDims() { ExpressionEvaluator ee(fusion_); bool can_be_single_warp = true; @@ -398,14 +184,12 @@ void GpuLower::collectPaddedParallelDims() { } } -void GpuLower::lower() { +void GpuLower::lower(Fusion* fusion) { FUSER_PERF_SCOPE("GpuLower::lower"); - - TORCH_INTERNAL_ASSERT(fusion_ != nullptr); + TORCH_INTERNAL_ASSERT(fusion != nullptr); TORCH_INTERNAL_ASSERT( active_gpu_lower == nullptr, "Nested lowering passes are not supported"); - // TODO(kir): revisit this struct LowerGuard { LowerGuard(GpuLower* gpu_lower) { active_gpu_lower = gpu_lower; @@ -414,17 +198,21 @@ void GpuLower::lower() { active_gpu_lower = nullptr; } } lower_guard(this); + // Copy fusion into a new kernel for processing + kernel_ = std::make_unique(fusion); + // Alias the fusion kernel caries around as a view of itself. + fusion_ = kernel_.get(); FusionGuard fg(fusion_); - - // Start with a fresh kernel - kernel_ = std::make_unique(); - // prepare for lowering validateIr(fusion_); - replaceSymbolicSizes(); + collectPaddedParallelDims(); - trivial_reduction_info_.build(fusion_, this); + + replaceSymbolicSizes(fusion_); + + trivial_reduction_info_.build(fusion_); + trivialReductionReplacement(fusion_, trivialReductionInfo()); // In the future we may directly use this map, but for now it will propagate // and validate (to some extent) the parallelization strategy. @@ -447,9 +235,12 @@ void GpuLower::lower() { parallelDimensionMap().build(fusion_); if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) { - std::cout << parallelDimensionMap().toString(); + std::cout << "Parallel dimension map:" << std::endl; + std::cout << parallel_dimension_map_.toString() << std::endl; } + concretized_broadcast_domains_.build(fusion_); + // Compute thread predicates. 
Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); @@ -469,61 +260,67 @@ void GpuLower::lower() { nonDivisibleSplitInfo().build(fusion_); - // Set the kernel inputs & outputs - for (auto input : fusion_->inputs()) { - kernel_->addInput(GpuLower::lowerValue(input)); - } - - for (auto output : fusion_->outputs()) { - kernel_->addOutput(GpuLower::lowerValue(output)); - } + doubleBufferInfo().build(fusion_); // Run our passes keeping the lowered expressions and forwarding // them // Reorder expressions for loop-nest generation respecting computeAt // relationships - auto sorted_exprs = reorderExprsForComputeAt(); + const auto exprs_sorted = reorderExprsForComputeAt(); // Generate loop-nests and place each expression at its // corresponding loop - const auto lowered_exprs = LoopNestGenerator::loweredExprs(sorted_exprs); + const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted); + + // Replace trivial reductions, Transpose, Shift, Gather, and View ops with + // unary ops since they're not separately processed in lowering. + const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered); // Insert allocations - const auto alloced_exprs = insertAllocations(lowered_exprs); + const auto exprs_alloced = insertAllocations(exprs_unary_replaced); // Insert read after write smem syncs - const auto raw_sync_exprs = insertRawThreadSynchronization(alloced_exprs); + const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced); // Reuse memory locations - const auto reuse_mem_exprs = reuseMemoryAllocations(raw_sync_exprs); + const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync); - // Inserts predicates after this, need to be careful in later passes when - // inserting in loop nest structure as insertions could be on if then else - // instead of directly on a for loop - const auto unrolled_loops = UnrollPass::runPass(fusion_, reuse_mem_exprs); + // Insert SyncThreads at end of for-loop to avoid WAR race condition + const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem); - const auto unrolled_mv_loops = - processMisalignedVectorization(fusion_, unrolled_loops); + const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync); - // Insert SyncThreads at end of for-loop to avoid WAR race condition - const auto war_sync_exprs = insertWarThreadSynchronization(unrolled_mv_loops); + // This pass inserts predicates as well as branches in the code. Up until now + // the code is explicitly single shot for loop based. Need to be careful in + // later passes when doing any kind of insertions in loop nest structure as + // insertions could be on if then or else instead of directly on a for loop. + const auto exprs_unrolled_loops = + UnrollPass::runPass(fusion_, exprs_double_buffered); - const auto indexed_loops = IndexLowering::getIndexedExprs(war_sync_exprs); + const auto exprs_unrolled_mv_loops = + processMisalignedVectorization(exprs_unrolled_loops); - const auto exprs_with_fused_broadcast = fuseWarpReduce(indexed_loops); + const auto exprs_indexed_loops = + IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops); - const auto conditional_loops = - generateConditionalFromPredicate(fusion_, exprs_with_fused_broadcast); + // TODO: It seems this type of optimization would be far easier to implement + // on fusion ir than kernel ir. We should likely refactor this to at least run + // before allocation insertion. 
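// Summary note (descriptive only, not part of this diff): by this point
// lower() has, in order, sorted expressions for computeAt, generated loop
// nests, replaced trivial reductions and view-like ops with unary sets,
// inserted allocations, added RAW syncs, reused memory, added WAR syncs,
// applied double buffering, run unrolling/predication, handled misaligned
// vectorization, and produced indexed expressions. The steps that follow
// fuse warp reductions, turn predicates into explicit conditionals, insert
// magic-zero updates, drop nop nodes, and finalize the kernel IR.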
+ const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops); + + const auto exprs_conditional_loops = + generateConditionalFromPredicate(exprs_with_fused_broadcast); // Insert fake zero updates to make sure nvrtc doesn't blow out register use // on index and predicate reuse - const auto register_adjusted = insertMagicZero(conditional_loops); + const auto exprs_register_adjusted = insertMagicZero(exprs_conditional_loops); - const auto cleaned_up_loops = KIRCleaner::cleanUp(register_adjusted); + const auto exprs_cleaned_up_loops = + KIRCleaner::cleanUp(exprs_register_adjusted); // We now have the lowered expressions, finalize the kernel IR - kernel_->finalize(cleaned_up_loops); + kernel_->finalize(exprs_cleaned_up_loops); } kir::Kernel* GpuLower::kernel() const { @@ -531,213 +328,9 @@ kir::Kernel* GpuLower::kernel() const { return kernel_.get(); } -// Maps Fusion IR nodes to the Kernel IR counterparts -class GpuLower::KernelIrMapper : private OptInConstDispatch { - public: - explicit KernelIrMapper(GpuLower* gpu_lower) - : gpu_lower_(gpu_lower), ir_builder_(gpu_lower->kernel()) {} - - kir::Val* lowerValue(const Val* value) { - const auto it = gpu_lower_->kir_val_map_.find(value); - if (it != gpu_lower_->kir_val_map_.end()) { - return it->second; - } else { - handle(value); - const auto kir_value = gpu_lower_->kir_val_map_[value]; - TORCH_CHECK(kir_value != nullptr); - - // Lower the value definition, if any - if (value->isScalar()) { - if (auto def = value->definition()) { - const auto kir_def = lowerExpr(def); - TORCH_INTERNAL_ASSERT(kir_value->definition() == kir_def); - } - } - - return kir_value; - } - } - - kir::Expr* lowerExpr(const Expr* expr) { - const auto it = gpu_lower_->kir_expr_map_.find(expr); - if (it != gpu_lower_->kir_expr_map_.end()) { - return it->second; - } else { - handle(expr); - const auto lowered_node = gpu_lower_->kir_expr_map_[expr]; - TORCH_CHECK(lowered_node != nullptr); - return lowered_node; - } - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - } - - private: - void handle(const Statement* node) final { - OptInConstDispatch::handle(node); - } - - void handle(const Val* node) final { - OptInConstDispatch::handle(node); - } - - void handle(const Expr* node) final { - OptInConstDispatch::handle(node); - } - - void handle(const TensorDomain* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const IterDomain* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const TensorView* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const Bool* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const Double* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const Int* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const NamedScalar* node) final { - const auto lowered_node = ir_builder_.create( - node->name(), node->getDataType().value()); - 
TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const UnaryOp* node) final { - const auto lowered_node = ir_builder_.create( - node->getUnaryOpType(), - lowerValue(node->out()), - lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const BinaryOp* node) final { - const auto lowered_node = ir_builder_.create( - node->getBinaryOpType(), - lowerValue(node->out()), - lowerValue(node->lhs()), - lowerValue(node->rhs())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const TernaryOp* node) final { - const auto lowered_node = ir_builder_.create( - node->getTernaryOpType(), - lowerValue(node->out()), - lowerValue(node->in1()), - lowerValue(node->in2()), - lowerValue(node->in3())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const ReductionOp* node) final { - auto out_tv = node->out()->as(); - // If trivial reduction operation lower to set operation. - if (std::all_of( - out_tv->domain()->domain().begin(), - out_tv->domain()->domain().end(), - [&](IterDomain* id) { - // If id is a reduction axis, is it a trivial reduction? - if (id->isReduction()) { - return gpu_lower_->trivialReductionInfo().isDerived(id); - } else { - return true; - } - })) { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK( - gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - return; - } - - const auto lowered_node = ir_builder_.create( - node->getReductionOpType(), - lowerValue(node->init()), - lowerValue(node->out()), - lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const WelfordOp* node) final { - auto lowerOptional = [&](Val* v) { return v ? 
lowerValue(v) : nullptr; }; - const auto lowered_node = ir_builder_.create( - lowerValue(node->outVar()), - lowerValue(node->outAvg()), - lowerValue(node->outN()), - lowerValue(node->initVar()), - lowerValue(node->initAvg()), - lowerValue(node->initN()), - lowerOptional(node->inVar()), - lowerValue(node->inAvg()), - lowerValue(node->inN())); - - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const BroadcastOp* node) final { - const auto lowered_node = ir_builder_.create( - lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const TransposeOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const ShiftOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const GatherOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const ViewOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - private: - GpuLower* gpu_lower_ = nullptr; - kir::IrBuilder ir_builder_; -}; - -kir::Val* GpuLower::lowerValue(const Val* val) { - KernelIrMapper kir_mapper(this); - return kir_mapper.lowerValue(val); -} - -kir::Expr* GpuLower::lowerExpr(const Expr* expr) { - KernelIrMapper kir_mapper(this); - return kir_mapper.lowerExpr(expr); -} - GpuLower* GpuLower::current() { + TORCH_INTERNAL_ASSERT( + active_gpu_lower != nullptr, "No active GpuLower available"); return active_gpu_lower; } diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index b807bb4d480..b97c6ac1837 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -1,14 +1,17 @@ #pragma once -#include +#include #include #include #include #include #include +#include #include #include +#include +#include #include #include #include @@ -29,29 +32,27 @@ namespace cuda { // container for this information that we can reuse. Would be nice to generate // such a structure and propagate it through lowering. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API GpuLower { +class TORCH_CUDA_CU_API GpuLower : public NonCopyable { class KernelIrMapper; public: - GpuLower() = default; + GpuLower() = delete; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit GpuLower(Fusion* fusion) : fusion_(fusion) { - lower(); + explicit GpuLower(Fusion* fusion) { + lower(fusion); } kir::Kernel* kernel() const; - //! Converts a Fusion IR value into the Kernel IR equivalent - kir::Val* lowerValue(const Val* val); - - //! Converts a Fusion IR expression into the Kernel IR equivalent - kir::Expr* lowerExpr(const Expr* expr); - //! Returns the currently active lowering object //! 
(or nullptr if no lowering is in progress) static GpuLower* current(); + ConcretizedBroadcastDomains& concretizedBroadcastDomains() { + return concretized_broadcast_domains_; + } + const ThreadPredicateMap& threadPredMap() const { return thread_pred_map_; } @@ -68,7 +69,7 @@ class TORCH_CUDA_CU_API GpuLower { return ca_parallel_map_; } - const auto& trivialReductionInfo() const { + const TrivialReductionInfo& trivialReductionInfo() const { return trivial_reduction_info_; } @@ -120,16 +121,12 @@ class TORCH_CUDA_CU_API GpuLower { return non_divisible_split_info_; } - private: - void lower(); + DoubleBufferInfo& doubleBufferInfo() { + return double_buffer_info_; + } - // TensorViews are all based on symbolic sizes. When we first initialize them - // we don't know if they're inputs or outputs which would mean that they have - // runtime shapes. Intermediate tensors (those not going to global memory) do - // not have this information. Since we need to have the correct information in - // the kernel being fetched for shapes, we want to replace input and output - // tensors to reference the runtime structure containing sizes. - void replaceSymbolicSizes(); + private: + void lower(Fusion* fusion); // Goes through the parallelized iterdomains of the used TVs and find // the parallel dimensions that need to be padded to a multiples of @@ -140,11 +137,8 @@ class TORCH_CUDA_CU_API GpuLower { // Lowered Kernel IR std::unique_ptr kernel_; - // Fusion IR node to Kernel IR node mapping - std::unordered_map kir_val_map_; - std::unordered_map kir_expr_map_; - // Some stateful information during lowering + ConcretizedBroadcastDomains concretized_broadcast_domains_; ThreadPredicateMap thread_pred_map_; PredicateElimination pred_elimination_; ComputeAtMap ca_loop_map_; @@ -157,6 +151,7 @@ class TORCH_CUDA_CU_API GpuLower { ParallelDimensionMap parallel_dimension_map_; PartialSplitMap partial_split_map_; NonDivisibleSplitInfo non_divisible_split_info_; + DoubleBufferInfo double_buffer_info_; Fusion* fusion_ = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index 80e2e58c9cf..17a2db069d8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -1,10 +1,10 @@ #include #include +#include #include #include #include -#include #include #include @@ -22,40 +22,42 @@ namespace { //! Get string representation of Allocate size for symbolic comparison //! //! 
TODO: Some expr simplifications could also be helpful -class SymbolicSizePrinter : private kir::IrVisitor { +class SymbolicSizePrinter : private OptOutConstDispatch { public: static std::string printSize(const kir::Allocate* allocate) { SymbolicSizePrinter printer; - allocate->size()->accept(&printer); + printer.handle(allocate->size()); return printer.os_.str(); } private: - void visit(const kir::Int* node) final { + using OptOutConstDispatch::handle; + + void handle(const Int* node) final { if (auto def = node->definition()) { - def->accept(this); + OptOutConstDispatch::handle(def); } else if (node->isConst()) { os_ << *node->value(); } else { - os_ << "ki" << node->id(); + os_ << "ki" << node->name(); } } - void visit(const kir::NamedScalar* named_scalar) final { + void handle(const NamedScalar* named_scalar) final { os_ << "@" << named_scalar->name(); } - void visit(const kir::UnaryOp* unary_op) final { - os_ << unary_op->operation() << "("; - unary_op->in()->accept(this); + void handle(const UnaryOp* unary_op) final { + os_ << unary_op->getUnaryOpType() << "("; + OptOutConstDispatch::handle(unary_op); os_ << ")"; } - void visit(const kir::BinaryOp* binary_op) final { - os_ << binary_op->operation() << "("; - binary_op->lhs()->accept(this); + void handle(const BinaryOp* binary_op) final { + os_ << binary_op->getBinaryOpType() << "("; + OptOutConstDispatch::handle(binary_op->lhs()); os_ << ","; - binary_op->rhs()->accept(this); + OptOutConstDispatch::handle(binary_op->rhs()); os_ << ")"; } @@ -74,11 +76,11 @@ class BufferReuseDebugPrinter { DebugLineType line_type = DebugLineType::EXPR; }; - using DebugEntry = std::pair; + using DebugEntry = std::pair; using DebugEntryPtr = std::unique_ptr; public: - BufferReuseDebugPrinter() : ir_printer_(os_, false){}; + BufferReuseDebugPrinter() : ir_printer_(os_){}; std::string dumpDebugInfo() { os_.clear(); @@ -105,7 +107,7 @@ class BufferReuseDebugPrinter { private: friend class BufferUseDefInfo; - void pushBack(int lineno, kir::Expr* expr) { + void pushBack(int lineno, Expr* expr) { makeExprEntry(lineno, expr); } @@ -117,7 +119,7 @@ class BufferReuseDebugPrinter { makeScopeEntry(DebugLineType::END_BLOCK); } - void makeExprEntry(int lineno, kir::Expr* expr) { + void makeExprEntry(int lineno, Expr* expr) { auto debug_entry_ptr = std::make_unique(); debug_entry_ptr->first.lineno = lineno; debug_entry_ptr->second = expr; @@ -134,14 +136,14 @@ class BufferReuseDebugPrinter { debug_info_.emplace_back(std::move(debug_entry_ptr)); } - void handle(const kir::Expr* node) { + void handle(const Expr* node) { if (auto for_loop = dynamic_cast(node)) { handle(for_loop); } else if (auto ite = dynamic_cast(node)) { handle(ite); } else { indent(); - ir_printer_.printNode(node); + ir_printer_.handle(node); } if (auto alloc = dynamic_cast(node)) { printAllocInfo(alloc); @@ -151,9 +153,9 @@ class BufferReuseDebugPrinter { void handle(const kir::ForLoop* node) { indent(); os_ << "FOR "; - ir_printer_.printNode(node->index()); + ir_printer_.handle(node->index()); os_ << " in "; - ir_printer_.printNode(node->iter_domain()); + ir_printer_.handle(node->iter_domain()); os_ << ":\n"; } @@ -186,7 +188,7 @@ class BufferReuseDebugPrinter { private: std::stringstream os_; - kir::IrPrinter ir_printer_; + IrPrinter ir_printer_; int indent_level_ = 0; std::vector debug_info_; @@ -340,7 +342,7 @@ class BufferUseDefInfo { static constexpr long kRegisterSizeThreshold = 1; BufferUseDefInfo( - const std::vector& exprs, + const std::vector& exprs, BufferReuseDebugPrinter* debug_printer 
= nullptr) : debug_printer_(debug_printer) { if (debug_printer) { @@ -410,7 +412,7 @@ class BufferUseDefInfo { } private: - void handle(kir::Expr* expr) { + void handle(Expr* expr) { current_pos_++; if (debug_printer_) { debug_printer_->pushBack(current_pos_, expr); @@ -426,7 +428,7 @@ class BufferUseDefInfo { } } - void handleScope(const std::vector& exprs) { + void handleScope(const std::vector& exprs) { if (debug_printer_) { debug_printer_->pushScope(); } @@ -460,15 +462,15 @@ class BufferUseDefInfo { return; } - auto kir_tv = dynamic_cast(alloc->buffer()); - if (!kir_tv) { + auto tv = dynamic_cast(alloc->buffer()); + if (!tv) { return; } // Collect the allocate info data // Collect memory type, skip global buffers - auto mem_type = kir_tv->memoryType(); + auto mem_type = tv->getMemoryType(); if (mem_type != MemoryType::Local && mem_type != MemoryType::Shared) { return; } @@ -487,12 +489,12 @@ class BufferUseDefInfo { } } - auto data_type = kir_tv->dtype(); + auto data_type = tv->dtype(); auto size_print = SymbolicSizePrinter::printSize(alloc); // Make sure we don't have conflicting information on record TORCH_INTERNAL_ASSERT(!map_allocate_to_info_.count(alloc)); - TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(kir_tv->name())); + TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(tv->name())); // make AllocationUseDefInfo: auto alloc_info = makeUseDefInfo(); @@ -505,10 +507,10 @@ class BufferUseDefInfo { // record short cuts map_allocate_to_info_[alloc] = alloc_info; - map_tv_to_allocations_[kir_tv->name()] = alloc_info; + map_tv_to_allocations_[tv->name()] = alloc_info; } - void collectScopeUseDefInfo(const std::vector& exprs) { + void collectScopeUseDefInfo(const std::vector& exprs) { // Reset position pointer resetExprCounter(); TORCH_INTERNAL_ASSERT(global_scope_info_ != nullptr); @@ -516,14 +518,14 @@ class BufferUseDefInfo { handleScope(exprs); } - void collectScopeInfo(const std::vector& exprs) { + void collectScopeInfo(const std::vector& exprs) { // Reset position pointer resetExprCounter(); collectScopeInfoWithinLoop(exprs, nullptr); } void collectScopeInfoWithinLoop( - const std::vector& exprs, + const std::vector& exprs, kir::ForLoop* current_loop) { auto loop_info = makeScopeInfo(current_loop); for (auto expr : exprs) { @@ -584,22 +586,20 @@ class BufferUseDefInfo { // Iterate over the inputs and outputs of exprs and update // the liveness info of local buffers if applicaable. - void collectLivenessInfo(const kir::Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + void collectLivenessInfo(const Expr* expr) { + if (!ir_utils::isTvOp(expr)) { return; } - auto out_tv = expr->outputs()[0]->as(); - auto fuser_out_tv = out_tv->fuserTv(); + auto out_tv = expr->outputs()[0]->as(); // Collect all tv's that resolves broadcast in this // expr. The current analysis isn't enough to capture // their liveness range. 
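// A minimal standalone sketch (not part of this diff) of the position-based
// liveness bookkeeping BufferUseDefInfo performs nearby: each expression
// gets a monotonically increasing position, writes and reads extend a
// buffer's live interval, and two buffers may alias only if their intervals
// do not overlap. `BufferLiveness` and `LiveInterval` are hypothetical
// names, not the classes in this file.
#include <algorithm>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>

struct LiveInterval {
  int first_write = std::numeric_limits<int>::max();
  int last_read = -1;
  void markWrite(int pos) { first_write = std::min(first_write, pos); }
  void markRead(int pos) { last_read = std::max(last_read, pos); }
};

class BufferLiveness {
 public:
  void write(const std::string& buf) { intervals_[buf].markWrite(++pos_); }
  void read(const std::string& buf) { intervals_[buf].markRead(++pos_); }

  // Non-overlapping [first_write, last_read] intervals mean the second
  // buffer can reuse the first buffer's storage.
  bool canAlias(const std::string& a, const std::string& b) const {
    const auto& ia = intervals_.at(a);
    const auto& ib = intervals_.at(b);
    return ia.last_read < ib.first_write || ib.last_read < ia.first_write;
  }

 private:
  int pos_ = 0; // expression position counter, like current_pos_ above
  std::unordered_map<std::string, LiveInterval> intervals_;
};

int main() {
  BufferLiveness liveness;
  liveness.write("T1"); // pos 1
  liveness.read("T1");  // pos 2: last use of T1
  liveness.write("T2"); // pos 3: T2 written after T1 is dead
  liveness.read("T2");  // pos 4
  std::cout << std::boolalpha << liveness.canAlias("T1", "T2") << "\n"; // true
  return 0;
}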
- for (auto input_tv : - ir_utils::filterByType(expr->inputs())) { + for (auto input_tv : ir_utils::filterByType(expr->inputs())) { auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv); if (maybe_alloc_info.has_value()) { - if (isSerialBroadcastResolution(input_tv->fuserTv(), fuser_out_tv)) { + if (isSerialBroadcastResolution(input_tv, out_tv)) { maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_); } else { // Disable inner alias info for this buffer, since line number based @@ -621,8 +621,7 @@ class BufferUseDefInfo { } } } - for (auto output_tv : - ir_utils::filterByType(expr->outputs())) { + for (auto output_tv : ir_utils::filterByType(expr->outputs())) { auto maybe_alloc_info = getMaybeAllocInfoFromTV(output_tv); if (maybe_alloc_info.has_value()) { maybe_alloc_info.value()->inner_live_interval->markWrite(current_pos_); @@ -675,8 +674,7 @@ class BufferUseDefInfo { return nullptr; } - c10::optional getMaybeAllocInfoFromTV( - kir::TensorView* tv) { + c10::optional getMaybeAllocInfoFromTV(TensorView* tv) { auto alloc_it = map_tv_to_allocations_.find(tv->name()); if (alloc_it == map_tv_to_allocations_.end()) { return c10::nullopt; @@ -810,11 +808,11 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { //! Reuse Allocation nodes via pointer aliasing class AllocateReuseModifier { public: - static void modify(const std::vector& exprs) { + static void modify(const std::vector& exprs) { AllocateReuseModifier modifier(exprs); } - static void debugPrint(const std::vector& exprs) { + static void debugPrint(const std::vector& exprs) { BufferReuseDebugPrinter debug_printer; AllocateReuseModifier modifier(exprs, &debug_printer); std::cout << debug_printer.dumpDebugInfo(); @@ -822,7 +820,7 @@ class AllocateReuseModifier { private: AllocateReuseModifier( - const std::vector& exprs, + const std::vector& exprs, BufferReuseDebugPrinter* debug_printer_ = nullptr) : buffer_info_(exprs, debug_printer_) { // Perform in-place sharing first and then outer liveness @@ -941,7 +939,7 @@ class AllocateReuseModifier { return false; } - void handle(kir::Expr* expr) { + void handle(Expr* expr) { if (auto ite = dynamic_cast(expr)) { handle(ite); } else if (auto for_loop = dynamic_cast(expr)) { @@ -961,7 +959,7 @@ class AllocateReuseModifier { "lower_alias_memory: IfThenElse before unrolling is not yet supported"); } - void handleScope(const std::vector& exprs) { + void handleScope(const std::vector& exprs) { current_visible_buffer_stack_.emplace_back( std::make_unique()); for (auto expr : exprs) { @@ -990,10 +988,8 @@ class AllocateReuseModifier { } // Assume inputs are TV allocations, which should have been checked // before reaching this point. - auto this_tv = - alloc_info->alloc_expr->buffer()->as()->fuserTv(); - auto reuse_tv = - to_reuse->alloc_expr->buffer()->as()->fuserTv(); + auto this_tv = alloc_info->alloc_expr->buffer()->as(); + auto reuse_tv = to_reuse->alloc_expr->buffer()->as(); // Check the values in between the two buffers. auto vals_between_this_and_reuse = @@ -1068,8 +1064,8 @@ class AllocateReuseModifier { } bool allocationDomainsIndexMapped( - std::vector& alloc_domains, - std::vector& reuse_domains) { + std::vector& alloc_domains, + std::vector& reuse_domains) { // Require that the allocated domains are exactly mapped. if (alloc_domains.size() != reuse_domains.size()) { return false; @@ -1099,7 +1095,7 @@ class AllocateReuseModifier { // Do we have a true pointwise op? // (ie. 
a TV op, excluding direct assignments and reductions) bool isPointwiseTvOp(const Expr* expr) { - if (ir_utils::isTVOp(expr)) { + if (ir_utils::isTvOp(expr)) { return expr->isA() || expr->isA() || expr->isA(); } @@ -1108,7 +1104,7 @@ class AllocateReuseModifier { // Utility to capture reduction ops bool isReductionTvOp(const Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } return expr->isA() || expr->isA(); @@ -1116,7 +1112,7 @@ class AllocateReuseModifier { // Utility to capture reduction ops bool isBroadcastTvOp(const Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } return expr->isA(); @@ -1138,8 +1134,7 @@ class AllocateReuseModifier { } // namespace -std::vector reuseMemoryAllocations( - const std::vector& exprs) { +std::vector reuseMemoryAllocations(const std::vector& exprs) { FUSER_PERF_SCOPE("reuseMemoryAllocations"); bool debug_print = isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo); if (debug_print) { diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.h b/torch/csrc/jit/codegen/cuda/lower_alias_memory.h index 26b33b6d5dc..0d144b9f2f4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.h +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -28,8 +28,7 @@ namespace cuda { //! is not used after this op: //! then alias output Allocate to input Allocate. //! -std::vector reuseMemoryAllocations( - const std::vector& exprs); +std::vector reuseMemoryAllocations(const std::vector& exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index 2f70c275832..c03848ccff8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -1,10 +1,8 @@ -#include #include #include #include #include -#include -#include +#include #include #include @@ -17,8 +15,12 @@ namespace cuda { namespace { -class AllocationInserter : public kir::MutableIrVisitor { +class AllocationInserter : public kir::ExprMutator { private: + using kir::ExprMutator::handle; + + // Expanded version of BasicAllocInfo in lower_utils.h helps to track + // additional information struct AllocationInformation { // The for loop that the initialization of this allocation must be // placed in, nullptr if not within a loop @@ -26,7 +28,7 @@ class AllocationInserter : public kir::MutableIrVisitor { // The expression that the initialization of this allocation must // be placed before - kir::Expr* init_place_before = nullptr; + Expr* init_place_before = nullptr; // Keep track of the actual allocation loop. This can be different // from init_for_loop only with unswitched shared memory allocations, @@ -37,143 +39,96 @@ class AllocationInserter : public kir::MutableIrVisitor { // The expression that this allocation must be placed // before. Similar to alloc_for_loop, this is different from // init_place_before only with unswitched shared memory allocations. 
- kir::Expr* alloc_place_before = nullptr; + Expr* alloc_place_before = nullptr; // The allocation position relative to buffer size_t alloc_pos = 0; // The buffer this allocation is for - kir::TensorView* buffer = nullptr; - - // The allocation expression - kir::Allocate* alloc_expr = nullptr; - - // Initialization - kir::Expr* init_expr = nullptr; + TensorView* buffer = nullptr; // Info to transfer to GPU lower bool has_halo = false; // Local Iterdomains that this allocation covers - std::unique_ptr> allocation_domains; + std::unique_ptr> allocation_domains; }; // Find allocation point - void findAllocationPosition(AllocationInformation& info, kir::Expr* expr) { + // Fills info.buffer, info.alloc_pos, info.init_for_loop, + // info.init_place_before, info.alloc_for_loop, info.alloc_place_before + void fillAllocationInformation(AllocationInformation& info, Expr* expr) { size_t alloc_pos = 0; kir::ForLoop* init_for_loop = nullptr; - auto fuser_tv = info.buffer->fuserTv(); size_t fl_idx_next = 0; - - bool outer_alloc_found = false; - kir::ForLoop* alloc_for_loop = nullptr; - size_t alloc_fl_idx_next = 0; - - for (auto fl : for_loops) { - if (alloc_pos == fuser_tv->getComputeAtPosition()) { - break; - } - - if (fuser_tv->axis(alloc_pos)->isReduction()) { - const auto outputs = - FusionGuard::getCurFusion()->getTerminatingOutputs(); - TORCH_INTERNAL_ASSERT( - std::find(outputs.begin(), outputs.end(), fuser_tv) != - outputs.end(), - "Invalid computeAt of T", - fuser_tv->name(), - ". A reducation axis is detected within computeAt axes even though it is not an output tensor."); - break; - } - - auto fl_id = fl->iter_domain(); - - if (fl_id->parallelType() == ParallelType::Unroll) { - break; - } - - // Shared memory must be allocated outside of unswitched - // domains. See issue #1133. - if (fl_id->parallelType() == ParallelType::Unswitch && - fuser_tv->getMemoryType() == MemoryType::Shared) { - outer_alloc_found = true; - } - - auto local_id = gpu_lower->lowerValue(fuser_tv->axis(alloc_pos)) - ->as(); - - if (gpu_lower->caLoopMap().areMapped(local_id, fl_id)) { - alloc_pos++; - } - - init_for_loop = fl; - ++fl_idx_next; - - if (!outer_alloc_found) { - alloc_for_loop = fl; - ++alloc_fl_idx_next; + auto loop_alloc_info = + loop_utils::getAllocInformation(info.buffer, for_loops_); + + info.init_for_loop = loop_alloc_info.init_for_loop; + info.alloc_for_loop = loop_alloc_info.alloc_for_loop; + info.alloc_pos = loop_alloc_info.alloc_pos; + + auto next_fl = [](kir::ForLoop* fl, const std::vector fls) { + for (auto i : c10::irange(fls.size())) { + if (fl == fls[i]) { + if (i + 1 < fls.size()) { + return fls[i + 1]; + } + } } - } - - info.alloc_pos = alloc_pos; - info.init_for_loop = init_for_loop; + TORCH_INTERNAL_ASSERT(false, "Could not find desired loop."); + }; if (info.init_for_loop == nullptr) { - info.init_place_before = for_loops.size() > 0 ? for_loops[0] : expr; + info.init_place_before = for_loops_.size() > 0 ? 
for_loops_[0] : expr; } else { - if (info.init_for_loop == for_loops.back()) { + if (info.init_for_loop == for_loops_.back()) { // Inline allocation, place before expr info.init_place_before = expr; } else { // Place allocation after the last computeAt axis // TODO: may be more efficient to place before the first non-computeAt // axis - info.init_place_before = for_loops.at(fl_idx_next); + info.init_place_before = next_fl(info.init_for_loop, for_loops_); } } // Set the allocation loop and the place_before expression in the // same way as the initialization loop and place_before expression - if (!outer_alloc_found) { + if (info.alloc_for_loop == info.init_for_loop) { info.alloc_for_loop = info.init_for_loop; info.alloc_place_before = info.init_place_before; } else { - info.alloc_for_loop = alloc_for_loop; if (info.alloc_for_loop == nullptr) { - info.alloc_place_before = for_loops.size() > 0 ? for_loops[0] : expr; + info.alloc_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr; } else { // Since there must be an inner unswitched domain, // alloc_for_loop should never be the inner-most loop. - TORCH_INTERNAL_ASSERT(info.alloc_for_loop != for_loops.back()); - info.alloc_place_before = for_loops.at(alloc_fl_idx_next); + TORCH_INTERNAL_ASSERT(info.alloc_for_loop != for_loops_.back()); + info.alloc_place_before = next_fl(info.alloc_for_loop, for_loops_); } } } // Create initialization expression if init_val is non-null. - void createInitExpr(AllocationInformation& info, kir::Val* init_val) { + Expr* createInitExpr(AllocationInformation& info, Val* init_val) { if (init_val == nullptr) { - info.init_expr = nullptr; - return; + return nullptr; } - auto fuser_tv = info.buffer->fuserTv(); - - std::vector init_dims; - for (const auto axis_i : c10::irange(info.alloc_pos, fuser_tv->nDims())) { - if (info.buffer->fuserTv()->axis(axis_i)->isReduction() || - info.buffer->fuserTv()->axis(axis_i)->isBroadcast()) { + std::vector init_dims; + for (const auto axis_i : + c10::irange(info.alloc_pos, info.buffer->nDims())) { + if (info.buffer->axis(axis_i)->isReduction() || + info.buffer->axis(axis_i)->isBroadcast()) { continue; } - auto concrete_id = - gpu_lower - ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID( - fuser_tv->axis(axis_i))) - ->as(); + auto concrete_id = gpu_lower->caParallelMap().getConcreteMappedID( + info.buffer->axis(axis_i)); init_dims.push_back(concrete_id); } - kir::Expr* init_expr = ir_builder.create( - UnaryOpType::Set, info.buffer, init_val); + Expr* init_expr = + IrBuilder::create(UnaryOpType::Set, info.buffer, init_val); for (auto init_loop_it = init_dims.rbegin(); init_loop_it != init_dims.rend(); ++init_loop_it) { @@ -181,9 +136,9 @@ class AllocationInserter : public kir::MutableIrVisitor { kir::ForLoop* new_loop = nullptr; auto extent_with_halo = gpu_lower->haloInfo().getExtent(id); if (extent_with_halo) { - new_loop = ir_builder.create( + new_loop = IrBuilder::create( id, - ir_builder.create(c10::nullopt), + IrBuilder::create(c10::nullopt), nullptr, extent_with_halo, nullptr, @@ -191,31 +146,33 @@ class AllocationInserter : public kir::MutableIrVisitor { nullptr, false); } else { - new_loop = ir_builder.create(id); + new_loop = IrBuilder::create(id); } new_loop->body().push_back(init_expr); init_expr = new_loop; } - info.init_expr = init_expr; + return init_expr; } - std::vector getGlobalAllocationSizes(AllocationInformation& info) { + std::vector getGlobalAllocationSizes(AllocationInformation& info) { const auto& domain = info.buffer->domain(); - const auto& 
maybe_rfactor_domain = - domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain(); + const auto& maybe_rfactor_domain = domain->hasRFactor() + ? domain->getRFactorDomain() + : domain->getRootDomain(); - std::vector alloc_dims; + std::vector alloc_dims; for (const auto id : maybe_rfactor_domain) { if (id->isReduction() || id->isStride() || - id->iterType() == IterType::BroadcastWithoutStride) { + id->getIterType() == IterType::BroadcastWithoutStride) { continue; } auto extent = id->extent(); // Use halo-extended extent if found auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id); if (halo_extent.hasHalo()) { - extent = ir_builder.addExpr(extent, halo_extent.width()); + extent = IrBuilder::addExpr( + extent, IrBuilder::create(halo_extent.width())); } alloc_dims.push_back(extent); } @@ -244,7 +201,7 @@ class AllocationInserter : public kir::MutableIrVisitor { // fall back to the leaf-based allocation. // // See the FusionShiftDoubleSplit test for an example case. - std::vector getNonGlobalAllocExprWithHalo( + std::vector getNonGlobalAllocExprWithHalo( TensorView* tv, const std::vector& alloc_domains) { std::vector start_vals; @@ -255,18 +212,18 @@ class AllocationInserter : public kir::MutableIrVisitor { [](IterDomain* dom) { return dom->as(); }); // Get all exprs involved in generating the allocation IDs - auto exprs = ExprSort::getExprs(tv->fusion(), start_vals); + auto exprs = StmtSort::getExprs(tv->fusion(), start_vals); // Get the halo extent if found auto getExtent = [this](IterDomain* id) { auto extent = gpu_lower->haloInfo().getExtent(id); if (extent == nullptr) { - extent = gpu_lower->lowerValue(id->extent()); + extent = id->extent(); } return extent; }; - std::unordered_map known_extents; + std::unordered_map known_extents; // IterDomains that are allocated fully. 
For example, if an ID is // split and only one of them is used for allocation, that's not @@ -314,7 +271,7 @@ class AllocationInserter : public kir::MutableIrVisitor { } else { known_extents.insert( {split->in(), - ir_builder.mulExpr(outer_it->second, inner_it->second)}); + IrBuilder::mulExpr(outer_it->second, inner_it->second)}); } known_extents.erase(inner_it); known_extents.erase(outer_it); @@ -330,7 +287,7 @@ class AllocationInserter : public kir::MutableIrVisitor { } } - std::vector alloc_dims; + std::vector alloc_dims; for (auto root_axis : tv->getRootDomain()) { auto it = known_extents.find(root_axis); @@ -355,24 +312,22 @@ class AllocationInserter : public kir::MutableIrVisitor { return alloc_dims; } - std::vector getNonGlobalAllocExpr(AllocationInformation& info) { - auto fuser_tv = info.buffer->fuserTv(); - const auto memory_type = info.buffer->memoryType(); + std::vector getNonGlobalAllocExpr(AllocationInformation& info) { + const auto memory_type = info.buffer->getMemoryType(); TORCH_INTERNAL_ASSERT( memory_type != MemoryType::Global, "Invalid memory type: ", memory_type); - std::vector alloc_dims; + std::vector alloc_dims; bool has_halo = false; std::vector alloc_domains; - info.allocation_domains = std::make_unique>(); + info.allocation_domains = std::make_unique>(); - for (const auto axis_i : c10::irange(fuser_tv->nDims())) { - const auto local_id = - gpu_lower->lowerValue(fuser_tv->axis(axis_i))->as(); + for (const auto axis_i : c10::irange(info.buffer->nDims())) { + const auto local_id = info.buffer->axis(axis_i); // Don't use reduction/stride/broadcast axis in the allocation // computation @@ -381,16 +336,14 @@ class AllocationInserter : public kir::MutableIrVisitor { continue; } - auto concrete_id = - gpu_lower - ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID( - fuser_tv->axis(axis_i))) - ->as(); + auto concrete_id = gpu_lower->caParallelMap().getConcreteMappedID( + info.buffer->axis(axis_i)); const bool is_block_dim = - isParallelTypeBlockDim(concrete_id->parallelType()); + isParallelTypeBlockDim(concrete_id->getParallelType()); const bool is_thread_dim = - isParallelTypeThreadDim(concrete_id->parallelType()); - const bool is_thread = isParallelTypeThread(concrete_id->parallelType()); + isParallelTypeThreadDim(concrete_id->getParallelType()); + const bool is_thread = + isParallelTypeThread(concrete_id->getParallelType()); if (axis_i < info.alloc_pos) { // Even when the axis is outside the allocation position, if the @@ -403,7 +356,7 @@ class AllocationInserter : public kir::MutableIrVisitor { (memory_type == MemoryType::Global && is_thread))) { continue; } - alloc_domains.push_back(fuser_tv->axis(axis_i)); + alloc_domains.push_back(info.buffer->axis(axis_i)); } else { if ( // If shared memory, don't use any IDs bound to a grid dimension @@ -413,12 +366,13 @@ class AllocationInserter : public kir::MutableIrVisitor { (memory_type == MemoryType::Local && is_thread)) { continue; } - alloc_domains.push_back(fuser_tv->axis(axis_i)); + alloc_domains.push_back(info.buffer->axis(axis_i)); } auto extent = concrete_id->extent(); - if (gpu_lower->haloInfo().getExtent(fuser_tv->axis(axis_i)) != nullptr) { + if (gpu_lower->haloInfo().getExtent(info.buffer->axis(axis_i)) != + nullptr) { has_halo = true; } @@ -430,20 +384,19 @@ class AllocationInserter : public kir::MutableIrVisitor { // the halo extents from leaf IDs to root IDs if (has_halo) { info.has_halo = true; - return getNonGlobalAllocExprWithHalo(fuser_tv, alloc_domains); + return 
getNonGlobalAllocExprWithHalo(info.buffer, alloc_domains); } return alloc_dims; } - void createAllocExpr(AllocationInformation& info, bool is_output) { + kir::Allocate* createAllocExpr(AllocationInformation& info, bool is_output) { if (is_output) { - info.alloc_expr = nullptr; - return; + return nullptr; } - std::vector alloc_dims; - const MemoryType memory_type = info.buffer->memoryType(); + std::vector alloc_dims; + const MemoryType memory_type = info.buffer->getMemoryType(); if (memory_type == MemoryType::Global) { alloc_dims = getGlobalAllocationSizes(info); @@ -453,60 +406,74 @@ class AllocationInserter : public kir::MutableIrVisitor { if (alloc_dims.size() == 0 && info.buffer->domain()->noReductions().size() != 0) { - alloc_dims.push_back(ir_builder.create(1)); + alloc_dims.push_back(info.buffer->container()->oneVal()); + } + + // Double the allocation size if double-buffered. Record the + // original size for indexing. + if (info.buffer->isDoubleBuffered()) { + Val* original_alloc_size = nullptr; + for (auto alloc_dim : alloc_dims) { + if (original_alloc_size == nullptr) { + original_alloc_size = alloc_dim; + } else { + original_alloc_size = + IrBuilder::mulExpr(original_alloc_size, alloc_dim); + } + } + GpuLower::current()->doubleBufferInfo().setOriginalAllocSize( + info.buffer, original_alloc_size); + alloc_dims.push_back(IrBuilder::create(2)); } // Create the allocation node - info.alloc_expr = ir_builder.create( - info.buffer, info.buffer->memoryType(), alloc_dims); + return IrBuilder::create( + info.buffer, info.buffer->getMemoryType(), alloc_dims); } - void handle(kir::Expr* expr) { - if (!ir_utils::isTVOp(expr) || expr->isA()) { - expr->accept(this); + void handle(Expr* expr) override { + if (!ir_utils::isTvOp(expr) || expr->isA()) { + ExprMutator::handle(expr); return; } // // Found where the allocation needs to be inserted for (auto out : expr->outputs()) { - if (!out->isA()) { + if (!out->isA()) { continue; } - auto out_tv = out->as(); - auto default_val = - gpu_lower->predicateElimination().getInitValue(out_tv->fuserTv()); + auto out_tv = out->as(); + auto default_val = gpu_lower->predicateElimination().getInitValue(out_tv); - kir::Val* init = nullptr; - if (expr->isA() && out_tv->fuserTv()->hasReduction()) { + Val* init = nullptr; + if (expr->isA() && out_tv->hasReduction()) { TORCH_INTERNAL_ASSERT( default_val == nullptr, "Reduction should not have a default initialization value for predicate elimination."); - init = expr->as()->init(); - } else if (expr->isA()) { + init = expr->as()->init(); + } else if (expr->isA()) { TORCH_INTERNAL_ASSERT( default_val == nullptr, "Welford should not have a default initialization value for predicate elimination."); - const auto welford = expr->as(); - if (out->id() == welford->outVar()->id()) { - init = welford->initVar() == nullptr - ? ir_builder.create(0) - : welford->initVar(); - } else if (out->id() == welford->outAvg()->id()) { - init = welford->initAvg() == nullptr - ? ir_builder.create(0) - : welford->initAvg(); + const auto welford = expr->as(); + if (out->name() == welford->outVar()->name()) { + init = welford->initVar() == nullptr ? IrBuilder::create(0) + : welford->initVar(); + } else if (out->name() == welford->outAvg()->name()) { + init = welford->initAvg() == nullptr ? 
IrBuilder::create(0) + : welford->initAvg(); } else { TORCH_INTERNAL_ASSERT( - out->id() == welford->outN()->id(), "Unreachable"); + out->name() == welford->outN()->name(), "Unreachable"); init = welford->initN(); } } else if (default_val != nullptr) { init = default_val; } - const bool is_output = gpu_lower->kernel()->isOutput(out); + const bool is_output = out->isFusionOutput(); // Don't need to alloc outputs, and if we don't need to initialize we're // done. @@ -516,150 +483,91 @@ class AllocationInserter : public kir::MutableIrVisitor { AllocationInformation allocation; allocation.buffer = out_tv; - findAllocationPosition(allocation, expr); - createAllocExpr(allocation, is_output); - createInitExpr(allocation, init); + fillAllocationInformation(allocation, expr); + + auto alloc_expr = createAllocExpr(allocation, is_output); + auto init_expr = createInitExpr(allocation, init); // Write information to GPULower - writeInfoToGPULower(allocation); + writeInfoToGPULower(allocation, alloc_expr); + + // Register allocations before initializations to keep them in the right + // order + if (alloc_expr != nullptr) { + if (allocation.buffer->getMemoryType() == MemoryType::Shared) { + // Shared allocations go at the begining of scope + TORCH_INTERNAL_ASSERT(!exprs_.empty()); + registerInsertBefore(exprs_[0], alloc_expr, nullptr); + } else { + TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr); + kir::Scope* scope = allocation.alloc_for_loop == nullptr + ? nullptr + : &allocation.alloc_for_loop->body(); + registerInsertBefore( + allocation.alloc_place_before, alloc_expr, scope); + } + } - allocs.push_back(std::move(allocation)); + if (init_expr != nullptr) { + TORCH_INTERNAL_ASSERT(allocation.init_place_before != nullptr); + kir::Scope* scope = allocation.init_for_loop == nullptr + ? nullptr + : &allocation.init_for_loop->body(); + registerInsertBefore(allocation.init_place_before, init_expr, scope); + } } } - void writeInfoToGPULower(const AllocationInformation& allocation) { + // Sends alloc_expr, info.has_halo, info.allocation_domains to GpuLower + void writeInfoToGPULower( + const AllocationInformation& allocation, + kir::Allocate* alloc_expr) { auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap(); - if (allocation.alloc_expr == nullptr) { + if (alloc_expr == nullptr) { // Skip output allocation. 
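// A minimal standalone sketch (not part of this diff) of the deferred-
// mutation pattern that kir::ExprMutator subclasses such as
// AllocationInserter rely on: insertions are registered while walking the
// expression list and applied only after the traversal, so the list is
// never modified while it is being iterated. `MiniMutator` and its members
// are simplified hypothetical stand-ins, not the kir classes.
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct MiniMutator {
  std::vector<std::string> exprs;
  // (reference expr, expr to insert before it)
  std::vector<std::pair<std::string, std::string>> pending_inserts;

  void registerInsertBefore(const std::string& ref, const std::string& e) {
    pending_inserts.emplace_back(ref, e);
  }

  void traverseAndInsert() {
    for (const auto& e : exprs) {
      if (e.rfind("T", 0) == 0) {
        // Pretend every tensor op needs an allocation placed before it.
        registerInsertBefore(e, "alloc(" + e.substr(0, 2) + ")");
      }
    }
    // Apply the recorded mutations only after the walk is complete.
    for (const auto& [ref, ins] : pending_inserts) {
      auto it = std::find(exprs.begin(), exprs.end(), ref);
      if (it != exprs.end()) {
        exprs.insert(it, ins);
      }
    }
  }
};

int main() {
  MiniMutator m;
  m.exprs = {"T1 = T0 + 1", "T2 = T1 * 2"};
  m.traverseAndInsert();
  for (const auto& e : m.exprs) {
    std::cout << e << "\n"; // alloc(T1), T1 = ..., alloc(T2), T2 = ...
  }
  return 0;
}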
return; } TORCH_INTERNAL_ASSERT( - !lower_alloc_info_map.count(allocation.alloc_expr), + !lower_alloc_info_map.count(alloc_expr), "duplicated allocation info entry"); // Create info entry for GPULower auto lower_alloc_info_ptr = std::make_unique(); - lower_alloc_info_ptr->alloc_expr = allocation.alloc_expr; + lower_alloc_info_ptr->alloc_expr = alloc_expr; lower_alloc_info_ptr->has_halo = allocation.has_halo; if (allocation.allocation_domains) { lower_alloc_info_ptr->alloc_domains = *(allocation.allocation_domains); } // Write entry to the stored map - lower_alloc_info_map[allocation.alloc_expr] = - std::move(lower_alloc_info_ptr); + lower_alloc_info_map[alloc_expr] = std::move(lower_alloc_info_ptr); } - void visit(kir::ForLoop* fl) final { - for_loops.push_back(fl); - // Modifying in place, make a copy of the vector - const std::vector exprs = fl->body().exprs(); - for (auto expr : exprs) { - handle(expr); - } - for_loops.pop_back(); - } - - void visit(kir::IfThenElse*) final { + void handle(kir::IfThenElse*) final { TORCH_INTERNAL_ASSERT( false, "Pass does not support conditional statements, ", "this pass should be run before any conditionals are placed in code."); } - AllocationInserter(std::vector _loop_nests) - : loop_nests_(std::move(_loop_nests)), - gpu_lower(GpuLower::current()), - ir_builder(gpu_lower->kernel()) { - // Compute all allocations - const std::vector exprs = loop_nests_; - for (auto expr : exprs) { - handle(expr); - } - - // First, place allocations of dynamic smem tensors at the very - // beginning of the expr list. Traverse backward as they should be - // placed in topological order. - for (auto it = allocs.rbegin(); it != allocs.rend(); ++it) { - const auto& alloc = *it; - if (alloc.alloc_expr == nullptr) { - continue; - } - // Dynamic smem exprs need to be at the begining of the kernel outside for - // loops - if (alloc.buffer->memoryType() == MemoryType::Shared && - !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) { - loop_nests_.insert(loop_nests_.begin(), alloc.alloc_expr); - } - } - - // Place the remaining allocations. - for (const auto& alloc : allocs) { - if (alloc.alloc_expr == nullptr) { - continue; - } - if (alloc.buffer->memoryType() == MemoryType::Shared && - !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) { - continue; - } - if (alloc.alloc_for_loop == nullptr) { - auto place_before_it = std::find( - loop_nests_.begin(), loop_nests_.end(), alloc.alloc_place_before); - TORCH_INTERNAL_ASSERT( - place_before_it != loop_nests_.end(), - "Could not figure out where to place allocation. 
", - "Use of the buffer, ", - toString(alloc.buffer), - ", could not be found.", - toString(alloc.alloc_place_before)); - loop_nests_.insert(place_before_it, alloc.alloc_expr); - } else { - alloc.alloc_for_loop->body().insert_before( - alloc.alloc_place_before, alloc.alloc_expr); - } - } - - // Now that allocations are in place, place the initializations - for (const auto& alloc : allocs) { - if (alloc.init_expr == nullptr) { - continue; - } - if (alloc.init_for_loop == nullptr) { - auto place_before_it = std::find( - loop_nests_.begin(), loop_nests_.end(), alloc.init_place_before); - // Don't need a check here as if the allocation placement succeeded - // this will too - loop_nests_.insert(place_before_it, alloc.init_expr); - } else { - alloc.init_for_loop->body().insert_before( - alloc.init_place_before, alloc.init_expr); - } - } + AllocationInserter(const std::vector& exprs) + : gpu_lower(GpuLower::current()) { + kir::ExprMutator::traverseAndInsert(exprs); } private: - std::deque allocs; - - std::vector for_loops; - - std::vector loop_nests_; - GpuLower* gpu_lower; - kir::IrBuilder ir_builder; - public: - static std::vector insert( - const std::vector& loop_nests) { - AllocationInserter inserter(loop_nests); - return inserter.loop_nests_; + static std::vector insert(const std::vector& exprs) { + AllocationInserter inserter(exprs); + return inserter.exprs_; } }; } // namespace -std::vector insertAllocations( - const std::vector& exprs) { +std::vector insertAllocations(const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertAllocations"); return AllocationInserter::insert(exprs); } diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.h b/torch/csrc/jit/codegen/cuda/lower_allocation.h index bc0344ca19f..45ebeac03f7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.h +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.h @@ -1,8 +1,7 @@ #pragma once -#include +#include -#include #include #include @@ -17,7 +16,7 @@ namespace cuda { //! logic duplication struct LocalAllocationInfo { kir::Allocate* alloc_expr = nullptr; - std::vector alloc_domains; + std::vector alloc_domains; bool has_halo = false; }; @@ -25,7 +24,7 @@ using LocalAllocationInfoMap = std::unordered_map>; //! Insert buffer allocations -std::vector insertAllocations(const std::vector& exprs); +std::vector insertAllocations(const std::vector& exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp new file mode 100644 index 00000000000..c8110413de7 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -0,0 +1,508 @@ +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +unsigned int getDoubleBufferAxisPosition(const TensorView* tv) { + // Double-buffering prefetches the next subregion of the tensor by + // doubling the allocation. The subregion is defined by the axes + // at the CA position till the inner-most position. There must be + // at least one axis that is outside (left) of the CA position, + // which defines the loop where prefetching is applied. Therefore, + // the CA position must be larger than 0. 
+ + TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0); + + // Unroll must not exist outside of double-buffer axis + auto first_unroll_it = std::find_if( + tv->domain()->domain().begin(), + tv->domain()->domain().end(), + [](const auto axis) { + return axis->getParallelType() == ParallelType::Unroll; + }); + + const int first_unroll_pos = + std::distance(tv->domain()->domain().begin(), first_unroll_it); + + const int unroll_or_ca_pos = + std::min((int)tv->getComputeAtPosition(), first_unroll_pos); + + TORCH_INTERNAL_ASSERT( + unroll_or_ca_pos > 0, + "Invalid tensor to double-buffer. Valid double buffer axis not found due to Unroll. ", + tv->toString()); + + int valid_pos = -1; + // Skip parallelized or broadcast axes + for (int i = unroll_or_ca_pos - 1; i >= 0; --i) { + auto pt = tv->axis(i)->getParallelType(); + if (!isParallelTypeThread(pt) && !tv->axis(i)->isBroadcast()) { + valid_pos = i; + break; + } + } + + TORCH_INTERNAL_ASSERT( + valid_pos >= 0, + "Invalid tensor to double-buffer. Valid double buffer axis not found. ", + tv->toString()); + + return valid_pos; +} + +IterDomain* getDoubleBufferAxis(const TensorView* tv) { + return tv->axis((int)getDoubleBufferAxisPosition(tv)); +} + +void validateDoubleBufferedTensor(const TensorView* tv) { + auto double_buffer_pos = getDoubleBufferAxisPosition(tv); + + // Like vectorization, only UnaryOp::Set with another TensorView is + // considered. + auto def = tv->definition(); + TORCH_INTERNAL_ASSERT( + def->isA() && + def->as()->getUnaryOpType() == UnaryOpType::Set, + "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set is supported: ", + def->toString()); + + TORCH_INTERNAL_ASSERT( + def->as()->in()->isA(), + "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set with TensorView is supported: ", + def->toString()); + + // Require the producer tensor to have been computed entirely for + // the double-buffering loop. Otherwise, the producer itself would + // also need to be double-bufferred. + auto producer = def->as()->in()->as(); + TORCH_INTERNAL_ASSERT( + producer->getComputeAtPosition() <= double_buffer_pos, + "Invalid tensor to double-buffer. The computeAt position of the producer tensor must be moved left: ", + producer->toString()); + + // Not strictly necessary, but only gmem -> smem or local and smem -> local + // are allowed. + const auto p_mem_type = producer->getMemoryType(); + const auto c_mem_type = tv->getMemoryType(); + TORCH_INTERNAL_ASSERT( + (p_mem_type == MemoryType::Global && + (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) || + (p_mem_type == MemoryType::Shared && c_mem_type == MemoryType::Local), + "Invalid tensor to double-buffer: ", + tv->toString(), + ". Producer memory type: ", + p_mem_type, + ". 
Consumer memory type: ", + c_mem_type); + + return; +} + +namespace { + +// Initial inspection of a fusion to find and validate double buffered tensors +class DoubleBufferFusionInspector : private IterVisitor { + public: + DoubleBufferFusionInspector(Fusion* fusion, DoubleBufferInfo& db_info) + : db_info_(db_info) { + traverse(fusion); + } + + private: + using IterVisitor::handle; + + void handle(TensorView* tv) final { + if (!tv->isDoubleBuffered()) { + return; + } + + validateDoubleBufferedTensor(tv); + + auto db_axis = getDoubleBufferAxis(tv); + + db_info_.setDoubleBufferAxis(tv, db_axis); + } + + private: + DoubleBufferInfo& db_info_; +}; + +// The type of replicated double-buffer loops +enum class LoopType { Prologue, Main, Epilogue }; + +// The epilogue loop is only created when the producer of a double +// buffer tensor is on smem, in which case it would otherwise require +// an additional predicate to guard buffer overruns. When it's on +// gmem, that isn't the case, so it does not need to create an +// epilogue loop. +bool requireEpilogue(const std::vector& exprs) { + return std::any_of(exprs.begin(), exprs.end(), [](const UnaryOp* uop) { + return uop->in()->as()->getMemoryType() == MemoryType::Shared; + }); +} + +// Replicates double buffer loops for Prologue, Main, and +// Epilogue. Prologue only copies the load expressions of double +// buffered tensors, whereas Epilogue does any expression other than +// the loads. Main copies everything. +class DoubleBufferLoopCloner : public kir::IrVisitor { + public: + static kir::ForLoop* clone( + kir::ForLoop* double_buffer_loop, + const std::vector& double_buffer_load_exprs, + LoopType loop_type) { + DoubleBufferLoopCloner cloner( + double_buffer_loop, double_buffer_load_exprs, loop_type); + cloner.clone(); + return cloner.cloned_top_level_loop_; + } + + private: + DoubleBufferLoopCloner( + kir::ForLoop* double_buffer_loop, + const std::vector& double_buffer_load_exprs, + LoopType loop_type) + : double_buffer_loop_(double_buffer_loop), + double_buffer_load_exprs_(double_buffer_load_exprs), + loop_type_(loop_type) {} + + using kir::IrVisitor::handle; + + void clone() { + const auto gpu_lower = GpuLower::current(); + + // Cloning the double buffer loop as follows: + // + // Prologue: 0 to 1 + // Main: 0 to (extent-1) + // Epilogue: (extent-1) to extent + + auto index = IrBuilder::create(c10::nullopt); + auto start = double_buffer_loop_->start(); + auto stop = double_buffer_loop_->stop(); + + if (loop_type_ == LoopType::Prologue) { + TORCH_INTERNAL_ASSERT(start->isZeroInt()); + stop = gpu_lower->kernel()->oneVal(); + } else if ( + loop_type_ == LoopType::Main && + requireEpilogue(double_buffer_load_exprs_)) { + stop = IrBuilder::subExpr( + double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); + } else if (loop_type_ == LoopType::Epilogue) { + TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_)); + start = IrBuilder::subExpr( + double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); + } + + cloned_top_level_loop_ = IrBuilder::create( + double_buffer_loop_->iter_domain(), + index, + start, + stop, + gpu_lower->kernel()->oneVal(), + false, + nullptr, + double_buffer_loop_->isUnrollRequired()); + + handle(double_buffer_loop_); + } + + void handle(kir::ForLoop* fl) final { + const auto gpu_lower = GpuLower::current(); + + kir::ForLoop* cloned_loop = fl == double_buffer_loop_ + ? 
cloned_top_level_loop_ + : IrBuilder::create(fl); + + cloned_scopes_.push_back(&cloned_loop->body()); + + kir::IrVisitor::handle(fl); + + cloned_scopes_.pop_back(); + + // Add the cloned loop into the parent loop body only when the + // cloned loop contains expressions. + if (!cloned_loop->body().empty() && !cloned_scopes_.empty()) { + cloned_scopes_.back()->push_back(cloned_loop); + } + } + + void handle(kir::IfThenElse* ite) final { + TORCH_INTERNAL_ASSERT(false, "No IfThenElse should exist yet"); + } + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + TORCH_INTERNAL_ASSERT(!cloned_scopes_.empty()); + + if (loop_type_ == LoopType::Main) { + cloned_scopes_.back()->push_back(expr); + return; + } + + // In Prologue and Epilogue, either load expressions or anything + // else are copied. Note that there can be multiple exprs defining + // double buffered TVs (e.g., buffer initialization). + + auto out_tv = ir_utils::getTvOutput(expr); + const auto is_double_buffer_load_expr = std::any_of( + double_buffer_load_exprs_.begin(), + double_buffer_load_exprs_.end(), + [out_tv](const auto load_expr) { + auto double_buffer_tv = ir_utils::getTvOutput(load_expr); + TORCH_INTERNAL_ASSERT(double_buffer_tv != nullptr); + return out_tv == double_buffer_tv; + }); + if ((loop_type_ == LoopType::Prologue && is_double_buffer_load_expr) || + (loop_type_ == LoopType::Epilogue && !is_double_buffer_load_expr)) { + cloned_scopes_.back()->push_back(expr); + } + } + + private: + kir::ForLoop* double_buffer_loop_ = nullptr; + const std::vector& double_buffer_load_exprs_; + const LoopType loop_type_; + + kir::ForLoop* cloned_top_level_loop_ = nullptr; + std::deque cloned_scopes_; +}; + +using InsertionInfo = std::unordered_map>; + +// Traverse lowered loop-nests and find all double buffer loops and +// associated load expressions. 
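// A minimal standalone sketch (not part of this diff) of the iteration
// ranges the loop cloner above assigns to the three copies of a
// double-buffer loop: Prologue [0, 1), Main [0, extent-1) when an epilogue
// is required (producer in shared memory) or [0, extent) otherwise, and
// Epilogue [extent-1, extent). `LoopRange` and `needs_epilogue` are
// hypothetical names used only for illustration.
#include <iostream>

struct LoopRange {
  int start;
  int stop;
};

LoopRange prologueRange(int /*extent*/) {
  return {0, 1};
}

// needs_epilogue: true when the producer is in shared memory, so the last
// iteration must not prefetch past the end of the producer buffer.
LoopRange mainRange(int extent, bool needs_epilogue) {
  return {0, needs_epilogue ? extent - 1 : extent};
}

LoopRange epilogueRange(int extent) {
  return {extent - 1, extent};
}

int main() {
  const int extent = 8;
  const bool needs_epilogue = true;

  const auto pro = prologueRange(extent);
  const auto main_loop = mainRange(extent, needs_epilogue);
  const auto epi = epilogueRange(extent);

  std::cout << "prologue: [" << pro.start << ", " << pro.stop << ")\n"
            << "main:     [" << main_loop.start << ", " << main_loop.stop
            << ")\n"
            << "epilogue: [" << epi.start << ", " << epi.stop << ")\n";
  return 0;
}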
+class DoubleBufferLoopNestInspector : private kir::IrVisitor { + public: + static InsertionInfo run(const std::vector& exprs) { + DoubleBufferLoopNestInspector inspector(exprs); + return inspector.insertion_info_; + } + + private: + DoubleBufferLoopNestInspector(const std::vector& exprs) { + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(UnaryOp* uop) final { + const auto gpu_lower = GpuLower::current(); + + auto out_tv = ir_utils::getTvOutput(uop); + + if (out_tv == nullptr) { + return; + } + + // Ignore init loop + if (!out_tv->isDoubleBuffered() || !uop->in()->isA()) { + return; + } + + auto double_buffer_loop = + gpu_lower->doubleBufferInfo().getDoubleBufferLoop(out_tv, for_loops_); + + TORCH_INTERNAL_ASSERT( + double_buffer_loop != nullptr, + "No double buffer loop found for a double buffered tensor: ", + out_tv->toString()); + + validateDoubleBufferLoop(double_buffer_loop); + + insertion_info_[double_buffer_loop].push_back(uop); + } + + static void validateDoubleBufferLoop(kir::ForLoop* loop) { + TORCH_INTERNAL_ASSERT( + loop->start()->isZeroInt(), "Unsupported loop: ", loop->toString()); + TORCH_INTERNAL_ASSERT( + loop->step()->isOneInt(), "Unsupported loop: ", loop->toString()); + TORCH_INTERNAL_ASSERT( + !loop->vectorize(), + "Vectorized loop should not be the allocation loop for double-buffered tensor: ", + loop->toString()); + TORCH_INTERNAL_ASSERT( + !loop->vectorize_shift(), + "Vectorize shift loop should not be the allocation loop for double-buffered tensor: ", + loop->toString()); + } + + InsertionInfo insertion_info_; +}; + +// Apply double buffering transformations +class DoubleBufferInserter : private kir::ExprMutator { + public: + // When there exist multiple double buffer loops, apply + // transformations to inner-most loops first. A single ExprMutator + // pass can only process one loop. + static std::vector run( + const std::vector& exprs, + InsertionInfo insertion_info) { + auto inserted_exprs = exprs; + while (!insertion_info.empty()) { + DoubleBufferInserter inserter(inserted_exprs, insertion_info); + inserted_exprs = inserter.exprs_; + } + return inserted_exprs; + } + + private: + DoubleBufferInserter( + const std::vector& exprs, + InsertionInfo& insertion_info) + : insertion_info_(insertion_info) { + auto num_double_buffer_loops = insertion_info.size(); + traverseAndInsert(exprs); + TORCH_INTERNAL_ASSERT(processed_loop_ != nullptr); + TORCH_INTERNAL_ASSERT(insertion_info.size() == num_double_buffer_loops - 1); + } + + using kir::ExprMutator::handle; + + void handle(kir::ForLoop* loop) final { + kir::ExprMutator::handle(loop); + + // If another loop is already taken care of, no more loop should + // be done in the same pass + if (processed_loop_ != nullptr) { + return; + } + + auto it = insertion_info_.find(loop); + if (it == insertion_info_.end()) { + return; + } + + insert(loop, it->second); + processed_loop_ = loop; + insertion_info_.erase(loop); + } + + void insert( + kir::ForLoop* double_buffer_loop, + const std::vector& loads) { + auto prologue_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, LoopType::Prologue); + registerInsertBefore(double_buffer_loop, prologue_loop); + + auto write_to_smem = + std::any_of(loads.begin(), loads.end(), [](const UnaryOp* uop) { + return uop->out()->as()->getMemoryType() == + MemoryType::Shared; + }); + + // RAW sync is not inserted for double buffered tensors. The only + // exception is the prologue load. 
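// A minimal standalone sketch (not part of this diff) of the fixed-point
// driver used by DoubleBufferInserter::run() above: each pass over the
// expressions transforms exactly one pending loop and removes it from the
// work list, and the driver simply repeats until nothing is pending. The
// strings below are hypothetical stand-ins for kir::ForLoop pointers and
// their load expressions.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Exprs = std::vector<std::string>;
using WorkList = std::map<std::string, std::string>; // loop -> load expr

// One "pass": handle a single pending loop, then stop.
Exprs runSinglePass(const Exprs& exprs, WorkList& pending) {
  Exprs out = exprs;
  if (!pending.empty()) {
    auto it = pending.begin();
    out.push_back("double_buffered(" + it->first + ")");
    pending.erase(it); // exactly one entry is consumed per pass
  }
  return out;
}

int main() {
  Exprs exprs = {"loopA", "loopB"};
  WorkList pending = {{"loopA", "load0"}, {"loopB", "load1"}};

  // Keep re-running the pass until every pending loop has been processed.
  while (!pending.empty()) {
    exprs = runSinglePass(exprs, pending);
  }

  for (const auto& e : exprs) {
    std::cout << e << "\n";
  }
  return 0;
}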
+ if (write_to_smem) { + auto sync = IrBuilder::create(); + registerInsertBefore(double_buffer_loop, sync); + } + + auto main_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, LoopType::Main); + registerReplace(double_buffer_loop, main_loop); + + if (requireEpilogue(loads)) { + auto epilogue_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, LoopType::Epilogue); + registerInsertAfter(double_buffer_loop, epilogue_loop); + } + } + + private: + InsertionInfo& insertion_info_; + kir::ForLoop* processed_loop_ = nullptr; +}; + +} // namespace + +void DoubleBufferInfo::build(Fusion* fusion) { + DoubleBufferFusionInspector inspector(fusion, *this); +} + +DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) { + TORCH_INTERNAL_ASSERT( + tv->isDoubleBuffered(), "Not a double-buffered tensor: ", tv->toString()); + return map_[tv]; +} + +void DoubleBufferInfo::setDoubleBufferAxis( + const TensorView* tv, + IterDomain* axis) { + getTvInfo(tv).double_buffer_axis = axis; +} + +IterDomain* DoubleBufferInfo::getDoubleBufferAxis(const TensorView* tv) { + if (!tv->isDoubleBuffered()) { + return nullptr; + } + + return getTvInfo(tv).double_buffer_axis; +} + +kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop( + IterDomain* axis, + const std::vector& loops, + bool ignore_prologue) { + auto loop_it = std::find_if(loops.begin(), loops.end(), [&](const auto loop) { + return GpuLower::current()->caIndexMap().areMapped( + loop->iter_domain(), axis) && + (!ignore_prologue || !loop->stop()->isOneInt()); + }); + + if (loop_it != loops.end()) { + return *loop_it; + } else { + return nullptr; + } +} + +kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop( + const TensorView* tv, + const std::vector& loops, + bool ignore_prologue) { + auto axis = getDoubleBufferAxis(tv); + + if (axis == nullptr) { + return nullptr; + } + + return getDoubleBufferLoop(axis, loops, ignore_prologue); +} + +void DoubleBufferInfo::setOriginalAllocSize( + const TensorView* tv, + Val* original_alloc_size) { + getTvInfo(tv).original_alloc_size = original_alloc_size; +} + +Val* DoubleBufferInfo::getOriginalAllocSize(const TensorView* tv) { + if (!tv->isDoubleBuffered()) { + return nullptr; + } + + return getTvInfo(tv).original_alloc_size; +} + +std::vector DoubleBufferPass::run(const std::vector& exprs) { + auto insertion_info = DoubleBufferLoopNestInspector::run(exprs); + return DoubleBufferInserter::run(exprs, insertion_info); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h new file mode 100644 index 00000000000..96bc247f4ff --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -0,0 +1,142 @@ +#pragma once + +#include + +#include +#include +#include + +// Double buffering a tensor doubles its allocation size and uses two +// buffers to facilitate computation and memory access +// overlapping. The basic form of code looks like as follows: +// +// Before: +// for i +// x[S]; // allocation +// for j: +// x[j] = y[i, j] +// for j: +// ... = x[j] +// +// After: +// X[S * 2]; // allocation +// for i in 0 to 1: // Prologue +// for j: +// x[j] = y[i, j] +// +// for i in 0 to N-1: // Main +// for j: +// x[j + (1 - i % 2) * S] = y[i + 1, j] +// for j: +// ... = x[j + (i % 2) * S] +// +// for i in N-1 to N: // Epilogue +// for j: +// ... = x[j + (i % 2) * S] +// +// Here, S is the original size of tensor x. 
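// For illustration, assuming S = 4, the Main loop alternates between the two
// halves of the doubled allocation:
//
//   i = 0: writes x[4..7] (offset (1 - 0 % 2) * S = 4), reads x[0..3] (offset 0)
//   i = 1: writes x[0..3] (offset (1 - 1 % 2) * S = 0), reads x[4..7] (offset 4)
//
// i.e., each iteration prefetches y[i + 1, :] into the half that is not being
// read in that iteration.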
+// +// The i loop is the double buffer loop of tensor x, where double +// buffering is applied to the tensor. The first step of lowering is +// to find the double buffering axis for each double buffered +// tensor. It must not be parallelized as it isn't possible to double +// buffer parallelized loops. Also, an unrolled axis expands the +// allocation and is intended to make the loop completely unrolled, +// which also conflicts with double buffering. So, basically, the double +// buffering axis is the inner-most axis within the axes left +// of the CA position. However, when it is parallelized or unrolled, a +// further left axis is picked. +// +// Once the double buffer axis is determined, the main task is to +// replicate the corresponding double buffer loop as illustrated +// above. The Prologue loop is to just fetch the first element to +// populate the buffer. The main loop is mostly the same as the +// original loop, except for the indexing change to switch the two +// buffers. When used as a consumer, an offset of (1 - i % 2) * S is +// added, whereas (i % 2) * S is added when used as a producer. Here, +// i is the index of the double buffer loop. The Epilogue loop is just +// for the last iteration of the loop. Since the main loop reads one +// element ahead of the producer of the double buffered tensor, it +// would require an additional guard to prevent buffer overruns with +// the producer if the main loop were also used for the last +// iteration. However, the value loaded by the invalid load would not +// be used, so instead of adding the additional predicate, the Epilogue +// loop is replicated from the original loop, except for the load +// expression since it's not used. Note that this overrun does not +// happen when the producer is on gmem, so in that case, this +// additional replication is not done. +// +// When creating those three types of loops, additional care must be +// taken when multiple tensors are double buffered. When multiple +// tensors use the same loop as their double buffer loop, one pass of +// replication takes care of them at once, meaning the same Prologue, +// Main, Epilogue loops are used for the multiple tensors. +// +// Other tasks to do for a double buffer tensor include: +// - Move allocation to outside of the double buffer loop +// - Double the allocation size +// - Omit the RAW sync in the Main and Epilogue loops + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +unsigned int getDoubleBufferAxisPosition(const TensorView* tv); + +IterDomain* getDoubleBufferAxis(const TensorView* tv); + +void validateDoubleBufferedTensor(const TensorView* tv); + +class TORCH_CUDA_CU_API DoubleBufferPass { + public: + //! Apply double buffering transformations + static std::vector run(const std::vector& exprs); +}; + +class TORCH_CUDA_CU_API DoubleBufferInfo { + // Lowering information of double buffered tensors. + struct TvInfo { + IterDomain* double_buffer_axis = nullptr; + Val* original_alloc_size = nullptr; + }; + + public: + void build(Fusion* fusion); + + void setDoubleBufferAxis(const TensorView* tv, IterDomain* id); + + IterDomain* getDoubleBufferAxis(const TensorView* tv); + + //! Get a loop that matches with a given double-buffer axis. If + //! ignore_prologue is true, a matched loop is ignored if it's a + //! prologue loop. + static kir::ForLoop* getDoubleBufferLoop( + IterDomain* axis, + const std::vector& loops, + bool ignore_prologue = false); + + //! Get a loop that matches with the double-buffer axis of a given + //! 
double-buffered tensor. If ignore_prologue is true, a matched + //! loop is ignored if it's a prologue loop. + kir::ForLoop* getDoubleBufferLoop( + const TensorView* tv, + const std::vector& loops, + bool ignore_prologue = false); + + void setOriginalAllocSize(const TensorView* tv, Val* size); + + Val* getOriginalAllocSize(const TensorView* tv); + + private: + TvInfo& getTvInfo(const TensorView* tv); + + private: + //! Keeps track of information for lowering double buffered tensors + std::unordered_map map_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp index 2353ea9bbf5..84c72c08185 100644 --- a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp @@ -541,7 +541,7 @@ ExprGroup* ExprSegmentationSorter::makeEmptyGroup() { ExprGroup* ExprSegmentationSorter::makeEmptyGroup(Expr* expr) { auto group = makeEmptyGroup(); group->exprs().push_back(expr); - if (ir_utils::isTVOp(expr)) { + if (ir_utils::isTvOp(expr)) { auto out_tv = expr->outputs()[0]->as(); // Grab all id's that are shared with other tensors. for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) { @@ -721,7 +721,7 @@ std::vector getLocalDomainOrdering( std::unordered_set domains; for (auto expr : exprs) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { continue; } diff --git a/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp new file mode 100644 index 00000000000..fa84d1006a1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp @@ -0,0 +1,119 @@ +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +// Replace trivial reductions with unary ops. +class TrivialReductionReplacement : private OptOutMutator { + public: + TrivialReductionReplacement( + Fusion* fusion, + const TrivialReductionInfo& trivial_reduction_info) + : trivial_reduction_info_(trivial_reduction_info) { + FusionGuard fg(fusion); + auto exprs = StmtSort::getExprs(fusion); + for (auto expr : exprs) { + mutate(expr); + } + } + + private: + using OptOutMutator::mutate; + void mutate(ReductionOp* rop) final { + if (rop->out()->isA()) { + auto out_tv = rop->out()->as(); + if (std::all_of( + out_tv->domain()->domain().begin(), + out_tv->domain()->domain().end(), + [&](IterDomain* id) { + // If id is a reduction axis, is it a trivial reduction? + if (id->isReduction()) { + return trivial_reduction_info_.isDerived(id); + } else { + return true; + } + })) { + auto out = rop->out(); + auto in = rop->in(); + auto container = out->container(); + removeExpr(container, rop); + IrBuilder::create(container, UnaryOpType::Set, out, in); + } + } + } + + const TrivialReductionInfo& trivial_reduction_info_; +}; + +// Replaces Transpose, Shift, Gather, and View Ops with Unary Ops. 
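// For example, a TransposeOp producing T1 from T0 is replaced with
// UnaryOp(UnaryOpType::Set, T1, T0); the handlers below do the same for
// ShiftOp, GatherOp, and ViewOp.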
+class UnaryOpInserter : private kir::ExprMutator { + public: + static std::vector insert(const std::vector& exprs) { + UnaryOpInserter inserter(exprs); + return inserter.exprs_; + } + + private: + using kir::ExprMutator::handle; + + UnaryOpInserter(const std::vector& exprs) { + kir::ExprMutator::traverseAndInsert(exprs); + } + + void handle(TransposeOp* top) final { + auto out = top->out(); + auto in = top->in(); + auto container = out->container(); + registerReplace( + top, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } + + void handle(ShiftOp* sop) final { + auto out = sop->out(); + auto in = sop->in(); + auto container = out->container(); + registerReplace( + sop, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } + + void handle(GatherOp* gop) final { + auto out = gop->out(); + auto in = gop->in(); + auto container = out->container(); + registerReplace( + gop, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } + + void handle(ViewOp* vop) final { + auto out = vop->out(); + auto in = vop->in(); + auto container = out->container(); + registerReplace( + vop, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } +}; + +} // namespace + +void trivialReductionReplacement( + Fusion* fusion, + const TrivialReductionInfo& trivial_reduction_info) { + TrivialReductionReplacement replacement(fusion, trivial_reduction_info); +} + +// Transpose, Shift, Gather, and View Ops with Unary Set Ops +std::vector unarySetOpInserter(const std::vector& exprs) { + return UnaryOpInserter::insert(exprs); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h new file mode 100644 index 00000000000..e18f4a8f077 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Replaces trivial reductions with Unary Set Ops +void trivialReductionReplacement(Fusion*, const TrivialReductionInfo&); + +// Transpose, Shift, Gather, and View Ops with Unary Set Ops +std::vector unarySetOpInserter(const std::vector& exprs); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index d92dd279b17..b0ef14079c4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include @@ -13,30 +12,24 @@ namespace jit { namespace fuser { namespace cuda { -IndexLowering::IndexLowering() : ir_builder_(GpuLower::current()->kernel()) {} - -kir::Val* IndexLowering::lowerSrcIndex(kir::Val* src, kir::Val* dst) const { - if (auto tv = dynamic_cast(src)) { - TORCH_INTERNAL_ASSERT(dst->isA()); - return Index::getProducerIndex( - tv->fuserTv(), - dst->as()->fuserTv(), - scope_utils::getLoops(active_scope_expr_)); +Val* IndexLowering::lowerSrcIndex(Val* src, Val* dst) const { + if (auto tv = dynamic_cast(src)) { + TORCH_INTERNAL_ASSERT(dst->isA()); + return Index::getProducerIndex(tv, dst->as(), for_loops_); } else { return src; } } -kir::Val* IndexLowering::lowerDstIndex(kir::Val* dst) const { - if (auto tv = dynamic_cast(dst)) { - return Index::getConsumerIndex( - tv->fuserTv(), 
scope_utils::getLoops(active_scope_expr_)); +Val* IndexLowering::lowerDstIndex(Val* dst) const { + if (auto tv = dynamic_cast(dst)) { + return Index::getConsumerIndex(tv, for_loops_); } else { return dst; } } -void IndexLowering::pushBack(kir::Expr* expr) { +void IndexLowering::pushBack(Expr* expr) { if (active_scope_ == nullptr) { lowered_exprs_.push_back(expr); } else { @@ -44,78 +37,71 @@ void IndexLowering::pushBack(kir::Expr* expr) { } } -void IndexLowering::visit(const kir::IfThenElse* ite) { - const auto prev_scope_expr = active_scope_expr_; +void IndexLowering::handle(const kir::IfThenElse* ite) { const auto prev_scope = active_scope_; - // TODO(kir): try to avoid recreating new nodes and leaving old ones around - auto new_ite = ir_builder_.create(ite->predicate()); + auto new_ite = IrBuilder::create(ite->predicate()); pushBack(new_ite); - active_scope_expr_ = new_ite; active_scope_ = &new_ite->thenBody(); for (auto expr : ite->thenBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } active_scope_ = &new_ite->elseBody(); for (auto expr : ite->elseBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } active_scope_ = prev_scope; - active_scope_expr_ = prev_scope_expr; } -void IndexLowering::visit(const kir::ForLoop* for_loop) { - const auto prev_scope_expr = active_scope_expr_; +void IndexLowering::handle(const kir::ForLoop* for_loop) { const auto prev_scope = active_scope_; - auto new_for_loop = ir_builder_.create(for_loop); + auto new_for_loop = IrBuilder::create(for_loop); pushBack(new_for_loop); - active_scope_expr_ = new_for_loop; active_scope_ = &new_for_loop->body(); + for_loops_.push_back(new_for_loop); for (auto expr : for_loop->body().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } + for_loops_.pop_back(); active_scope_ = prev_scope; - active_scope_expr_ = prev_scope_expr; } -void IndexLowering::visit(const kir::UnaryOp* uop) { +void IndexLowering::handle(const UnaryOp* uop) { const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); - pushBack(ir_builder_.create(uop->operation(), out, in)); + pushBack(IrBuilder::create(uop->getUnaryOpType(), out, in)); } -void IndexLowering::visit(const kir::BinaryOp* bop) { +void IndexLowering::handle(const BinaryOp* bop) { const auto lhs = lowerSrcIndex(bop->lhs(), bop->out()); const auto rhs = lowerSrcIndex(bop->rhs(), bop->out()); const auto out = lowerDstIndex(bop->out()); - pushBack(ir_builder_.create(bop->operation(), out, lhs, rhs)); + pushBack(IrBuilder::create(bop->getBinaryOpType(), out, lhs, rhs)); } -void IndexLowering::visit(const kir::TernaryOp* top) { +void IndexLowering::handle(const TernaryOp* top) { const auto in1 = lowerSrcIndex(top->in1(), top->out()); const auto in2 = lowerSrcIndex(top->in2(), top->out()); const auto in3 = lowerSrcIndex(top->in3(), top->out()); const auto out = lowerDstIndex(top->out()); - pushBack( - ir_builder_.create(top->operation(), out, in1, in2, in3)); + pushBack(IrBuilder::create( + top->getTernaryOpType(), out, in1, in2, in3)); } namespace { // Get the size of the temporary work buffer for grid communication, this can be // grid reduction, broadcast, or grid welford. -kir::Val* getGridCommWorkBufferSize( - kir::IrBuilder& ir_builder, - const kir::TensorDomain* td) { +Val* getGridCommWorkBufferSize(const TensorDomain* td) { // The buffer size is the number of thread blocks multiplied by the // number of threads not used for reduction domains. 
// Note: Previously it was calculated based on the shape of the @@ -125,7 +111,7 @@ kir::Val* getGridCommWorkBufferSize( // size if the parallel dimensions are exact, but otherwise, just // computing the buffer size based on the tensor shape isn't // sufficient since there could be extra threads/blocks. - kir::Val* buffer_size = ir_builder.create(1); + Val* buffer_size = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeThreads) { auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); if (pt_dim == nullptr || pt_dim->isOneInt()) { @@ -133,33 +119,31 @@ kir::Val* getGridCommWorkBufferSize( } if (isParallelTypeThreadDim(pt) && std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) { - return out_id->parallelType() == pt && + return out_id->getParallelType() == pt && (out_id->isReduction() || out_id->isBroadcast()); })) { continue; } - buffer_size = ir_builder.mulExpr(buffer_size, pt_dim); + buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim); } return buffer_size; } -kir::Val* getGridSyncBufferSize( - kir::IrBuilder& ir_builder, - const kir::TensorDomain* td) { +Val* getGridSyncBufferSize(const TensorDomain* td) { // See the comment above for getGridCommWorkBufferSize. - kir::Val* buffer_size = ir_builder.create(1); + Val* buffer_size = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeBIDs) { auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); if (pt_dim == nullptr || pt_dim->isOneInt()) { continue; } if (std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) { - return out_id->parallelType() == pt && + return out_id->getParallelType() == pt && (out_id->isReduction() || out_id->isBroadcast()); })) { continue; } - buffer_size = ir_builder.mulExpr(buffer_size, pt_dim); + buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim); } return buffer_size; } @@ -167,26 +151,25 @@ kir::Val* getGridSyncBufferSize( // Allocate global buffer for a grid communication calls, i.e. grid reduce, grid // welford reduce, grid broadcast. 
kir::Allocate* allocGlobalBufferForGridComm( - kir::IrBuilder& ir_builder, - kir::Val* buffer_size, + Val* buffer_size, DataType dtype, bool zero_init) { - const std::vector new_buffer_ids = { - ir_builder.create(ir_builder.zeroVal(), buffer_size)}; - const auto buffer_domain = - ir_builder.create(new_buffer_ids); - const auto buffer_tv = ir_builder.create( - dtype, buffer_domain, MemoryType::Global); - return ir_builder.create( - buffer_tv, buffer_tv->memoryType(), nullptr, zero_init); + const std::vector new_buffer_ids = { + IrBuilder::create( + GpuLower::current()->kernel()->zeroVal(), buffer_size)}; + const auto buffer_domain = IrBuilder::create(new_buffer_ids); + const auto buffer_tv = + IrBuilder::create(buffer_domain, dtype, MemoryType::Global); + return IrBuilder::create( + buffer_tv, buffer_tv->getMemoryType(), nullptr, zero_init); } } // namespace -void IndexLowering::visit(const kir::ReductionOp* rop) { - TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(rop)); +void IndexLowering::handle(const ReductionOp* rop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(rop)); - const auto out_tv = rop->out()->as(); + const auto out_tv = rop->out()->as(); const auto out_domain = out_tv->domain(); const bool is_block_reduce = out_domain->hasBlockReduction(); @@ -199,7 +182,7 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { std::none_of( out_domain->domain().begin(), out_domain->domain().end(), - [](kir::IterDomain* id) { + [](IterDomain* id) { return !id->isThread() && id->isReduction() && !id->extent()->isOneInt(); }), @@ -212,11 +195,11 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { const auto out = lowerDstIndex(rop->out()); const auto in = lowerSrcIndex(rop->in(), rop->out()); - kir::ReductionOp* block_reduction_op = nullptr; + ReductionOp* block_reduction_op = nullptr; if (is_block_reduce) { - block_reduction_op = ir_builder_.create( - rop->operation(), rop->init(), out, in); + block_reduction_op = IrBuilder::create( + rop->getReductionOpType(), rop->init(), out, in); if (rop->predicate()) { block_reduction_op->setPredicate(rop->predicate()); } @@ -228,29 +211,22 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { if (is_grid_reduce) { const auto reduce_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridCommWorkBufferSize(ir_builder_, out_domain), - out->dtype(), - false); + getGridCommWorkBufferSize(out_domain), out->dtype(), false); const auto sync_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridSyncBufferSize(ir_builder_, out_domain), - DataType::Int, - true); + getGridSyncBufferSize(out_domain), DataType::Int, true); const auto grid_reduction_op = (block_reduction_op == nullptr) - ? ir_builder_.create( - rop->operation(), rop->init(), out, in) + ? IrBuilder::create( + rop->getReductionOpType(), rop->init(), out, in) : block_reduction_op; // The thread predicate for GridReduction needs to be set // separately from the main predicate. Do not combine them like // other expressions. const auto& thread_pred = - GpuLower::current()->threadPredMap().getPredicatedParallelTypes( - out_tv->fuserTv()); - auto grid_reduction = ir_builder_.create( + GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); + auto grid_reduction = IrBuilder::create( grid_reduction_op, reduce_buffer, sync_buffer); grid_reduction->setThreadPredicate(thread_pred); @@ -260,8 +236,8 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { // predicate does not work when the write predicate of the // blockReduce is different from the read predicate. 
if (is_block_reduce) { - grid_reduction->setPredicate( - ir_builder_.create(ir_builder_.trueVal())); + grid_reduction->setPredicate(IrBuilder::create( + GpuLower::current()->kernel()->trueVal())); } else { grid_reduction->setPredicate(rop->predicate()); } @@ -277,15 +253,15 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { } if (!is_block_reduce && !is_grid_reduce) { - // TODO(kir): this breaks our "SSA" form - pushBack(ir_builder_.create(rop->operation(), out, out, in)); + pushBack( + IrBuilder::create(rop->getReductionOpType(), out, out, in)); } } -void IndexLowering::visit(const kir::WelfordOp* wop) { - TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(wop)); +void IndexLowering::handle(const WelfordOp* wop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(wop)); - const auto out_tv = wop->outAvg()->as(); + const auto out_tv = wop->outAvg()->as(); const auto out_domain = out_tv->domain(); const bool is_block_reduce = out_domain->hasBlockReduction(); @@ -298,7 +274,7 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { std::none_of( out_domain->domain().begin(), out_domain->domain().end(), - [](kir::IterDomain* id) { + [](IterDomain* id) { return !id->isThread() && id->isReduction(); }), "Found a reduction stage that has both a non-parallelized ", @@ -322,18 +298,18 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { auto out_var = lowerDstIndex(wop->outVar()); auto out_N = lowerDstIndex(wop->outN()); - kir::WelfordOp* welford_op = ir_builder_.create( - out_var, + WelfordOp* welford_op = IrBuilder::create( out_avg, + out_var, out_N, - wop->initVar(), wop->initAvg(), + wop->initVar(), wop->initN(), - in_var, in_avg, + in_var, in_N); - kir::WelfordOp* block_welford_op = nullptr; + WelfordOp* block_welford_op = nullptr; if (is_block_reduce) { block_welford_op = welford_op; @@ -348,21 +324,17 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { if (is_grid_reduce) { // Buffer allocation - const auto work_buffer_size = - getGridCommWorkBufferSize(ir_builder_, out_domain); + const auto work_buffer_size = getGridCommWorkBufferSize(out_domain); - const auto out_var_buffer = allocGlobalBufferForGridComm( - ir_builder_, work_buffer_size, out_var->dtype(), false); - const auto out_avg_buffer = allocGlobalBufferForGridComm( - ir_builder_, work_buffer_size, out_avg->dtype(), false); - const auto out_N_buffer = allocGlobalBufferForGridComm( - ir_builder_, work_buffer_size, out_N->dtype(), false); + const auto out_var_buffer = + allocGlobalBufferForGridComm(work_buffer_size, out_var->dtype(), false); + const auto out_avg_buffer = + allocGlobalBufferForGridComm(work_buffer_size, out_avg->dtype(), false); + const auto out_N_buffer = + allocGlobalBufferForGridComm(work_buffer_size, out_N->dtype(), false); const auto sync_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridSyncBufferSize(ir_builder_, out_domain), - DataType::Int, - true); + getGridSyncBufferSize(out_domain), DataType::Int, true); // Grid Welford instantiation const auto grid_welford_op = @@ -372,10 +344,9 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { // separately from the main predicate. Do not combine them like // other expressions. 
const auto& thread_pred = - GpuLower::current()->threadPredMap().getPredicatedParallelTypes( - out_tv->fuserTv()); + GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); - auto grid_welford = ir_builder_.create( + auto grid_welford = IrBuilder::create( grid_welford_op, out_var_buffer, out_avg_buffer, @@ -400,18 +371,18 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { } } -void IndexLowering::visit(const kir::BroadcastOp* bop) { - TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(bop)); +void IndexLowering::handle(const BroadcastOp* bop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(bop)); - const auto out_tv = bop->out()->as(); + const auto out_tv = bop->out()->as(); const auto out = lowerDstIndex(bop->out()); const auto in = lowerSrcIndex(bop->in(), bop->out()); - auto indexed_expr = ir_builder_.create(out, in); + auto indexed_expr = + IrBuilder::create(out, in, bop->getBroadcastDimFlags()); const ParallelTypeBitmap parallel_bitmap = - GpuLower::current()->threadPredMap().getParallelBroadcastDomains( - out_tv->fuserTv()); + GpuLower::current()->threadPredMap().getParallelBroadcastDomains(out_tv); const bool block_x = parallel_bitmap.get(ParallelType::BIDx); const bool block_y = parallel_bitmap.get(ParallelType::BIDy); @@ -430,18 +401,12 @@ void IndexLowering::visit(const kir::BroadcastOp* bop) { // Grid broadcast const auto out_domain = out_tv->domain(); const auto broadcast_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridCommWorkBufferSize(ir_builder_, out_domain), - out->dtype(), - false); + getGridCommWorkBufferSize(out_domain), out->dtype(), false); const auto sync_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridSyncBufferSize(ir_builder_, out_domain), - DataType::Int, - true); + getGridSyncBufferSize(out_domain), DataType::Int, true); - auto grid_broadcast = ir_builder_.create( + auto grid_broadcast = IrBuilder::create( indexed_expr, broadcast_buffer, sync_buffer); if (bop->predicate()) { @@ -453,19 +418,19 @@ void IndexLowering::visit(const kir::BroadcastOp* bop) { pushBack(grid_broadcast); } -void IndexLowering::visit(const kir::Allocate* allocate) { +void IndexLowering::handle(const kir::Allocate* allocate) { // TODO(kir): remove the need for const_cast pushBack(const_cast(allocate)); // NOLINT } -void IndexLowering::visit(const kir::Sync* sync) { +void IndexLowering::handle(const kir::Sync* sync) { // TODO(kir): remove the need for const_cast pushBack(const_cast(sync)); // NOLINT } -void IndexLowering::generate(const std::vector& exprs) { +void IndexLowering::generate(const std::vector& exprs) { for (auto expr : exprs) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 5eb27c78f28..2f3af0061e1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -1,10 +1,10 @@ #pragma once -#include +#include #include #include -#include +#include #include #include @@ -14,10 +14,11 @@ namespace jit { namespace fuser { namespace cuda { -class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor { +// TODO: Replace with mutator as IndexLowering is replacing expr's with +// versions that are doing indexing +class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { public: - static std::vector getIndexedExprs( - std::vector incoming_exprs) { + static std::vector getIndexedExprs(std::vector incoming_exprs) { 
FUSER_PERF_SCOPE("GpuLower::Lower::IndexLowering::getIndexedExprs"); IndexLowering il; il.generate(incoming_exprs); @@ -25,28 +26,29 @@ class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor { } private: - IndexLowering(); + IndexLowering() = default; - void pushBack(kir::Expr*); + void pushBack(Expr*); - void visit(const kir::ForLoop*) final; - void visit(const kir::IfThenElse*) final; - void visit(const kir::UnaryOp*) final; - void visit(const kir::BinaryOp*) final; - void visit(const kir::TernaryOp*) final; - void visit(const kir::ReductionOp*) final; - void visit(const kir::WelfordOp*) final; - void visit(const kir::BroadcastOp*) final; - void visit(const kir::Allocate*) final; - void visit(const kir::Sync*) final; + void handle(const UnaryOp*) final; + void handle(const BinaryOp*) final; + void handle(const TernaryOp*) final; + void handle(const ReductionOp*) final; + void handle(const WelfordOp*) final; + void handle(const BroadcastOp*) final; - void generate(const std::vector& exprs); + void handle(const kir::ForLoop*) final; + void handle(const kir::IfThenElse*) final; + void handle(const kir::Allocate*) final; + void handle(const kir::Sync*) final; - kir::Val* lowerSrcIndex(kir::Val* val, kir::Val* dst) const; - kir::Val* lowerDstIndex(kir::Val* dst) const; + void generate(const std::vector& exprs); + + Val* lowerSrcIndex(Val* val, Val* dst) const; + Val* lowerDstIndex(Val* dst) const; private: - std::vector lowered_exprs_; + std::vector lowered_exprs_; // This is a slight work around as scope has a couple definitions, we have the // Scope that's in ForLoop/IfThenElse which is really just a wrapper around @@ -55,9 +57,10 @@ class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor { // could be either the body or else body of the IfThenElse. However, we want // to understand the nesting of IfThenElse/ForLoop nodes. kir::Scope* active_scope_ = nullptr; - kir::Expr* active_scope_expr_ = nullptr; - kir::IrBuilder ir_builder_; + // Track for loops to send to indexing. Similar to what's done in + // kir::IrVisitor + std::vector for_loops_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 0947ef0f579..77be88183ec 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -1,8 +1,8 @@ #include #include +#include #include -#include -#include +#include #include #include @@ -33,8 +33,8 @@ class SmemAllocMap { public: //! Insert a new node if it's a SMEM allocation void insert(kir::Allocate* alloc) { - if (auto tv = dynamic_cast(alloc->buffer())) { - if (tv->memoryType() == MemoryType::Shared) { + if (auto tv = dynamic_cast(alloc->buffer())) { + if (tv->getMemoryType() == MemoryType::Shared) { // Note that a TensorView can have two allocations due to // unswitch. auto p = map_.insert({tv, alloc}); @@ -50,290 +50,298 @@ class SmemAllocMap { } } - //! Get the buffer that is actually allocated for a given TV - kir::TensorView* getRealBuffer(kir::TensorView* tv) const { + //! Run through aliases to get the buffer that is actually allocated for a + //! 
given TV + TensorView* getRealBuffer(TensorView* tv) const { auto it = map_.find(tv); TORCH_INTERNAL_ASSERT( - it != map_.end(), "Allocation not found for ", kir::toString(tv)); + it != map_.end(), "Allocation not found for ", tv->toString()); const kir::Allocate* alloc = it->second; while (alloc->alias()) { alloc = alloc->alias(); } auto buf = alloc->buffer(); - TORCH_INTERNAL_ASSERT(buf->isA()); - return buf->as(); + TORCH_INTERNAL_ASSERT(buf->isA()); + return buf->as(); } private: - std::unordered_map map_; + std::unordered_map map_; }; -//! Insert WAR sync for a given ForLoop -class LocalSyncInserterForLoop { - using TvSet = std::unordered_set; +struct WarMemoryInfo { + // True if there's a sync after the last read within the alloc loop. + bool sync_after_read = false; - public: - //! Insert Sync nodes at the end of a given for-loop when a WAR - //! hazard may happen. - LocalSyncInserterForLoop(kir::ForLoop* fl, SmemAllocMap& alloc_map) - : alloc_map_(alloc_map) { - for (auto expr : fl->body().exprs()) { - handle(expr); - } + // True if there's a sync before the first write. There can be multiple writes + // from memory aliasing. + bool sync_before_write = false; - // No need to insert sync when the loop is not actually generated - if (fl->iter_domain()->isThread() || fl->iter_domain()->isBroadcast()) { - return; - } - - // Determine if any smem TV is written to at beginning of the for-loop - // and whether that smem TV is read from at the end of the for-loop - // Insert new SyncThreads at end of for-loop to prevent WAR race condition - // - // TODO: replace __syncthreads with __threadfence for alias ops - // - if (detectIntersection(initial_, final_) && - !fl->body().exprs().back()->isA() && !is_last_op_sync_) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - fl->body().push_back(ir_builder.create(true)); - initial_sync_ = true; - is_last_op_sync_ = true; - final_.clear(); - } - } + // Has there been a read of this memory location + bool read_hit = false; - const auto& initial() const { - return initial_; - } + // Has there been *the* write to this memory location, assumes single write + // instruction (needs to be before conditionals added to code) + bool write_hit = false; - const auto& final() const { - return final_; - } + // For loop this TV is compute_at'ed in. + kir::ForLoop* ca_loop = nullptr; +}; - const auto& all_smem_inputs() const { - return all_smem_inputs_; +// To prevent shared memory from being overwritten before it is read, a +// synchronization point has to be inserted either between the allocation of an +// SMEM buffer and where we write into it, or after the buffer's last read +// before exiting the allocation's scope. +// +// e.g. +// for i: +// "alloc A" in shared memory - This is really marked by the compute_at point +// sync_loc_0 +// for j: +// sync_loc_1 +// for k: +// sync_loc_2 +// A = ... +// for k: +// ... = ... A +// for j: +// for k: +// ... = ... A +// sync_loc_3 +// sync_loc_4 +// sync_loc_5 +// +// All sync locations here ensure that memory in A is finished +// being read before it is overwritten in the next iteration. +// +// Insertion of sync threads will be done from the innermost position to the +// outermost. If a sync protecting the buffer is not already placed, the +// location preferred for the sync threads is the last possible position. One +// future optimization could be to not sync on the last iteration of the loop +// the sync is placed in.
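// If none of the candidate locations above already contains a sync that
// protects A, the pass below appends a kir::Sync flagged as a WAR sync as the
// last expression of the loop in which A is allocated (the i loop in this
// example).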
+class WarSyncInserter : private kir::ExprMutator { + public: + static std::vector insert(const std::vector& exprs) { + WarSyncInserter inserter(exprs); + return inserter.exprs_; } - const auto& all_smem_outputs() const { - return all_smem_outputs_; + private: + //! Insert Sync nodes at the end of a given for-loop when a WAR + //! hazard may happen. + WarSyncInserter(const std::vector& exprs) { + auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap(); + for (const auto& entry : lower_alloc_info_map) { + alloc_map_.insert(entry.first); + } + kir::ExprMutator::traverseAndInsert(exprs); } - void handle(kir::Expr* expr) { - if (ir_utils::isTVOp(expr)) { - is_last_op_sync_ = false; - - // For this SyncInserter - if (initial_sync_) { - addInputSmemTvs(expr, final_); - } else { - addInputSmemTvs(expr, final_); - addOutputSmemTvs(expr, initial_); + void handle(kir::IfThenElse* ite) final { + TORCH_INTERNAL_ASSERT( + ite->elseBody().empty(), + "Pass does not support conditional flow,", + " needs to be done before conditional execution is lowered."); + kir::ExprMutator::handle(ite); + } + + void handle(kir::Sync* sync) final { + // Register the sync for the active for loop + sync_hit_.back() = true; + // Run through the active allocations, if a read was hit, register there was + // a sync after the read. If there's subsequent reads on this buffer the + // sync_after_read will be cleared. + for (auto& entry : smem_allocations_) { + auto& alloc_stack = entry.second; + if (alloc_stack.back().read_hit) { + alloc_stack.back().sync_after_read = true; } - - // For parent SyncInserter - addOutputSmemTvs(expr, all_smem_outputs_); - addInputSmemTvs(expr, all_smem_inputs_); - } else if (auto sync = dynamic_cast(expr)) { - handle(sync); - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - } else if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } else if (auto alloc = dynamic_cast(expr)) { - alloc_map_.insert(alloc); } } - void handle(kir::Sync* sync) { - is_last_op_sync_ = true; - initial_sync_ = true; - final_.clear(); - } - - void handle(kir::IfThenElse* ite) { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); + // Checks if fl or loops within it have hit a sync + bool syncWithin(kir::ForLoop* fl) { + // If outer most scope check the first sync_hit_ position + if (fl == nullptr) { + return sync_hit_[0]; } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); - } - } - void handle(kir::ForLoop* fl) { - LocalSyncInserterForLoop child_sync_inserter(fl, alloc_map_); - - const auto& child_inputs = child_sync_inserter.all_smem_inputs(); - const auto& child_outputs = child_sync_inserter.all_smem_outputs(); - const bool maybe_skipped = !fl->start()->isZeroInt() && - !isParallelTypeThread(fl->iter_domain()->parallelType()); - - // Default - Track all smem inputs / outputs - all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end()); - all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end()); - - // Propagate the last_op_sync flag from the child loop. If the - // child is deterministically executed at least once, just set the - // flag with the child flag. Otherwise, conservatively set the - // flag, i.e., if the current flag is true and the child flag is - // also true, we can say the last op is still sync. 
- if (!maybe_skipped) { - is_last_op_sync_ = child_sync_inserter.is_last_op_sync_; - } else { - is_last_op_sync_ = - is_last_op_sync_ && child_sync_inserter.is_last_op_sync_; - } + // Find the for loop we want to look within + auto fl_it = std::find(for_loops_.begin(), for_loops_.end(), fl); - // When the child is not guaranteed to have sync. - if (!child_sync_inserter.initial_sync_) { - // If no sync is yet found, add the child outputs to - // initial. - if (!initial_sync_) { - initial_.insert(child_outputs.begin(), child_outputs.end()); - } - // Add the child inputs to final even when inital_sync is false, - // which only means sync may not be found yet. - final_.insert(child_inputs.begin(), child_inputs.end()); - } else { - // Similar to the above case, but here, the child is guaranteed - // to have sync, so we only need to look at initial and final. - if (!initial_sync_) { - initial_.insert( - child_sync_inserter.initial().begin(), - child_sync_inserter.initial().end()); - } - if (!maybe_skipped) { - initial_sync_ = true; - final_.clear(); - } - final_.insert( - child_sync_inserter.final().begin(), - child_sync_inserter.final().end()); - } - } + // Convert it to an index, but add one for the outer most scope + auto fl_i = std::distance(for_loops_.begin(), fl_it) + 1; - static bool detectIntersection(const TvSet& left, const TvSet& right) { - for (auto item : left) { - if (right.find(item) != right.end()) { + // Start at that index and see if there's syncs within that for loop + for (auto i : c10::irange(fl_i, sync_hit_.size())) { + if (sync_hit_[i]) { return true; } } return false; } - void addOutputSmemTvs(const kir::Expr* expr, TvSet& set) { - for (auto out : expr->outputs()) { - if (auto tv = dynamic_cast(out)) { - if (tv->memoryType() == MemoryType::Shared) { - auto real_tv = alloc_map_.getRealBuffer(tv); - set.insert(real_tv); - } + void handle(Expr* expr) final { + // If not a tensor view expression continue with dispatch + if (!ir_utils::isTvOp(expr)) { + kir::ExprMutator::handle(expr); + return; + } + + // Mark write has been hit for all output tvs + auto out_tvs = ir_utils::filterByType(expr->outputs()); + for (auto out_tv : out_tvs) { + if (out_tv->getMemoryType() != MemoryType::Shared) { + continue; } + auto& entry = getMemInfo(out_tv); + + // If this is the first write and there's a sync in one of the loops after + // the compute at loop, then this buffer is protected. + if (syncWithin(entry.ca_loop) && !entry.write_hit) { + entry.sync_before_write = true; + } + entry.write_hit = true; } - } - void addInputSmemTvs(const kir::Expr* expr, TvSet& set) { - for (auto in : expr->inputs()) { - if (auto tv = dynamic_cast(in)) { - if (tv->memoryType() == MemoryType::Shared) { - auto real_tv = alloc_map_.getRealBuffer(tv); - set.insert(real_tv); - } + // Mark read was hit, if sync_after_read was set, clear it. + auto inp_tvs = ir_utils::filterByType(expr->inputs()); + for (auto inp_tv : inp_tvs) { + if (inp_tv->getMemoryType() != MemoryType::Shared) { + continue; } + auto& entry = getMemInfo(inp_tv); + entry.read_hit = true; + // Clear the sync_after_read if it was set because there was another write + entry.sync_after_read = false; } } - private: - //! Allocation map of SMEM buffers - SmemAllocMap& alloc_map_; - - //! Track Shared Memory Inputs (Reads) for parent for-loop - TvSet all_smem_inputs_; - - //! Track Shared Memory Outputs (Writes) for parent for-loop - TvSet all_smem_outputs_; - - //! Shared Memory Writes at beginning of the for-loop - //! 
before first SyncThreads - TvSet initial_; + void handle(kir::ForLoop* for_loop) final { + // Push loop scope information + auto prev_within_iter_loop_ = within_iter_loop_; + sync_hit_.push_back(false); - //! Shared Memory Reads at end of the for-loop - //! Cleared after each SyncThreads - TvSet final_; + // If there is no real iterating loop WAR syncs aren't necessary + within_iter_loop_ = within_iter_loop_ || + !(for_loop->iter_domain()->isThread() || + for_loop->iter_domain()->isBroadcast() || + for_loop->iter_domain()->extent()->isOneInt()); - //! Track first sync deterministically found in for-loop. Even when a - //! child loop has a sync, if it may not be executed due to non-zero - //! start value, this flag remains false. - bool initial_sync_ = false; + // Process the expressions in the for loop + kir::ExprMutator::handle(for_loop); - //! Track if last op is sync - bool is_last_op_sync_ = false; -}; - -class LocalSyncInserter { - public: - //! Write-After-Read race conditions are only found within for-loops. - //! Sync nodes are inserted directly into the for-loops. - //! The expressions are modified in-place and exprs is const. - static void insertSyncs(const std::vector& exprs) { - LocalSyncInserter inserter; - inserter.insert(exprs); - } + // Sync analysis and cleanup: + // + // Pop for loop stack inside WarMemoryInfo structs if they match this one. + // Erase empty entries so we don't continue to search over them + // + // Insert sync at end of this for loop if any of the entries require + std::vector to_erase; + bool insert_sync = false; + for (auto& entry : smem_allocations_) { + auto& alloc_stack = entry.second; + if (alloc_stack.size() && alloc_stack.back().ca_loop == for_loop) { + if (!alloc_stack.back().sync_after_read && + !alloc_stack.back().sync_before_write) { + insert_sync = within_iter_loop_; + } - private: - void insert(const std::vector& exprs) { - for (auto expr : exprs) { - if (auto fl = dynamic_cast(expr)) { - LocalSyncInserterForLoop sync_inserter(fl, alloc_map_); - } else if (auto ite = dynamic_cast(expr)) { - insert(ite->thenBody().exprs()); - insert(ite->elseBody().exprs()); - } else if (auto alloc = dynamic_cast(expr)) { - alloc_map_.insert(alloc); + alloc_stack.pop_back(); + if (alloc_stack.empty()) { + to_erase.push_back(entry.first); + } } } - } - private: + for (auto tv : to_erase) { + smem_allocations_.erase(tv); + } + + // WAR Sync is necessary in this loop, register its insertion. 
+ if (insert_sync) { + auto sync_expr = IrBuilder::create(true); + kir::ExprMutator::registerInsertAfter( + for_loop->body().exprs().back(), sync_expr, &for_loop->body()); + handle(sync_expr); + } + + // Pop for loop scope information + sync_hit_.pop_back(); + within_iter_loop_ = prev_within_iter_loop_; + } + + // Create a new WarMemoryInfo entry if required and return a reference to it, + // else return the WarMemoryInfo associated with tv + WarMemoryInfo& getMemInfo(TensorView* tv) { + auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv); + auto alloc_it = smem_allocations_.find(maybe_aliased_tv); + auto ca_loop = + loop_utils::getAllocInformation(tv, for_loops_).init_for_loop; + if (alloc_it == smem_allocations_.end()) { + WarMemoryInfo mem_info; + mem_info.ca_loop = ca_loop; + auto entry_it = + smem_allocations_ + .insert(std::make_pair( + maybe_aliased_tv, std::vector({mem_info}))) + .first; + return entry_it->second.back(); + } else if ( + maybe_aliased_tv != tv && alloc_it->second.back().ca_loop != ca_loop) { + WarMemoryInfo mem_info; + mem_info.ca_loop = ca_loop; + auto& alloc_stack = alloc_it->second; + alloc_stack.push_back(mem_info); + return alloc_stack.back(); + } + return alloc_it->second.back(); + } + + //! Allocation map of SMEM buffers. Needed because of SMEM buffer aliasing, + //! need to track the root of the alias to properly insert WAR hazard syncs SmemAllocMap alloc_map_; + + //! Is there a loop nest that has a non-trivial iteration (extent != 1) and + //! not bound to a block/thread. This indicates if a WAR sync is necessary, + //! otherwise the Expr is not in an iterating for loop. + bool within_iter_loop_ = false; + + // Track which loops have hit a sync. Used to see if there's a sync before + // write. + std::vector sync_hit_ = {false}; + + // Keep track of the active allocations we need to protect. Key is the + // "getRealBuffer", not the raw tv. There can be multiple WarMemoryInfo's + // because of aliasing. If the "getRealBuffer" tv has a compute at outside the + // alias tv, each aliased tv in a unique ca_loop has to be tracked separately + // for WAR insertion. + std::unordered_map> smem_allocations_; }; class ExprFlattener : private kir::IrVisitor { private: - void handle(kir::Expr* expr) { + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { if (expr->isA() || expr->isA()) { - expr->accept(this); + kir::IrVisitor::handle(expr); } else { - exprs_.push_back(expr); - } - } - - void visit(const kir::ForLoop* fl) final { - for (auto expr : fl->body().exprs()) { - handle(expr); - } - } - - void visit(const kir::IfThenElse* ite) final { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); + flat_exprs_.push_back(expr); } } private: - std::vector exprs_; + std::vector flat_exprs_; public: //! Flattens scopes extracting out a single ordered list of exprs. - static std::vector flatten( - const std::vector& loop_nests) { + static std::vector flatten(const std::vector& loop_nests) { ExprFlattener flattener; for (auto expr : loop_nests) { flattener.handle(expr); } - return flattener.exprs_; + return flattener.flat_exprs_; } }; @@ -342,53 +350,42 @@ class ValidatePlacementAfterWrites : private kir::IrVisitor { //! 
Validate no expr in writes found under loop static void validate( kir::ForLoop* loop, - const std::unordered_set& writes) { + const std::unordered_set& writes) { ValidatePlacementAfterWrites validator(writes); validator.handle(loop); } private: - ValidatePlacementAfterWrites(const std::unordered_set& writes) + using kir::IrVisitor::handle; + + ValidatePlacementAfterWrites(const std::unordered_set& writes) : writes_(writes) {} - void handle(kir::Expr* expr) { + void handle(Expr* expr) final { if (expr->isA() || expr->isA()) { - expr->accept(this); + kir::IrVisitor::handle(expr); } else { TORCH_INTERNAL_ASSERT( writes_.find(expr) == writes_.end(), "Block sync must be placed after ", - kir::toString(expr)); - } - } - - void visit(const kir::ForLoop* fl) final { - for (auto expr : fl->body().exprs()) { - handle(expr); - } - } - - void visit(const kir::IfThenElse* ite) final { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); + expr->toString()); } } private: - const std::unordered_set& writes_; + const std::unordered_set& writes_; }; -class ReadAfterWriteSyncs : public kir::MutableIrVisitor { +class ReadAfterWriteSyncs : public kir::ExprMutator { private: + using kir::ExprMutator::handle; + //! Traverse up the loop stack from loops_it and if a halo loop is //! found, place a given sync expr before the outer-most halo loop. bool insertBeforeHaloLoop( std::vector::iterator loops_it, kir::Sync* sync_expr, - const std::unordered_set& writes) { + const std::unordered_set& writes) { std::vector::iterator halo_loop_it; bool halo_loop_found = false; @@ -420,21 +417,21 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { if (halo_loop_it == for_loops_.begin()) { // place in global scope - auto place_before_it = - std::find(loop_nests_.begin(), loop_nests_.end(), halo_loop); - TORCH_INTERNAL_ASSERT(place_before_it != loop_nests_.end()); - loop_nests_.insert(place_before_it, sync_expr); + auto place_before_it = std::find(exprs_.begin(), exprs_.end(), halo_loop); + TORCH_INTERNAL_ASSERT(place_before_it != exprs_.end()); + exprs_.insert(place_before_it, sync_expr); } else { auto place_in = *(halo_loop_it - 1); - place_in->body().insert_before(halo_loop, sync_expr); + kir::ExprMutator::registerInsertBefore( + halo_loop, sync_expr, &place_in->body()); } return true; } - void handle(kir::Expr* expr) { - if (!ir_utils::isTVOp(expr) || expr->isA()) { - expr->accept(this); + void handle(Expr* expr) final { + if (!ir_utils::isTvOp(expr) || expr->isA()) { + kir::ExprMutator::handle(expr); return; } @@ -443,8 +440,8 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { auto last_writes = last_writes_.front(); last_writes_.pop_front(); // Found that a sync is needed - TORCH_INTERNAL_ASSERT(expr->outputs()[0]->isA()); - auto out_tv = expr->outputs()[0]->as(); + TORCH_INTERNAL_ASSERT(expr->outputs()[0]->isA()); + auto out_tv = expr->outputs()[0]->as(); // Find where a sync needs to be inserted // This is very similar to how allocations are placed, simply place sync @@ -454,39 +451,35 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { // out of or saving state for tensor view ID -> for loop // TODO: Explicitly test the 3 cases below - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto sync_expr = ir_builder.create(); - if (out_tv->fuserTv()->getComputeAtPosition() == 0) { + auto sync_expr = IrBuilder::create(); + if (out_tv->getComputeAtPosition() == 0) { // Sync should be placed at global scope, 
after its outer most loop if // it has one. - kir::Expr* place_after = for_loops_.size() > 0 ? for_loops_[0] : expr; - // Find location in loop_nests_ + Expr* place_after = for_loops_.size() > 0 ? for_loops_[0] : expr; + // Find location in exprs_ auto place_after_it = - std::find(loop_nests_.begin(), loop_nests_.end(), place_after); + std::find(exprs_.begin(), exprs_.end(), place_after); TORCH_INTERNAL_ASSERT( - place_after_it != loop_nests_.end(), + place_after_it != exprs_.end(), "Could not figure out where to place synchronization. ", "Tried to place after, ", - toString(place_after), + place_after->toString(), ", but could not find this expression at the global scope."); - loop_nests_.insert(place_after_it + 1, sync_expr); + + registerInsertAfter(*(place_after_it + 1), sync_expr, nullptr); } else { // Find the last loop in computeAt of out_tv, this is the loop where we // would place an allocation for out_tv - auto fuser_tv = out_tv->fuserTv(); - auto lowered_local_id = - GpuLower::current() - ->lowerValue(fuser_tv->axis( - (int)out_tv->fuserTv()->getComputeAtPosition() - 1)) - ->as(); + auto local_id = out_tv->axis((int)out_tv->getComputeAtPosition() - 1); auto loops_it = std::find_if( for_loops_.begin(), for_loops_.end(), - [&lowered_local_id](const auto& loop) { + [&local_id](const auto& loop) { return GpuLower::current()->caLoopMap().areMapped( - loop->iter_domain(), lowered_local_id) || - loop->iter_domain()->parallelType() == ParallelType::Unroll; + loop->iter_domain(), local_id) || + loop->iter_domain()->getParallelType() == + ParallelType::Unroll; }); TORCH_INTERNAL_ASSERT(loops_it != for_loops_.end()); @@ -497,7 +490,7 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { } auto place_in = *loops_it; - kir::Expr* place_after = nullptr; + Expr* place_after = nullptr; if (loops_it + 1 == for_loops_.end()) { // Inline allocation, place after expr @@ -509,22 +502,12 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { place_after = *(loops_it + 1); } - place_in->body().insert_after(place_after, sync_expr); + registerInsertAfter(place_after, sync_expr, &place_in->body()); } } } - void visit(kir::ForLoop* fl) final { - for_loops_.push_back(fl); - // Modifying in place, make a copy of the vector - const std::vector exprs = fl->body().exprs(); - for (auto expr : exprs) { - handle(expr); - } - for_loops_.pop_back(); - } - - void visit(kir::IfThenElse*) final { + void handle(kir::IfThenElse*) final { TORCH_INTERNAL_ASSERT( false, "Pass does not support conditional statements, ", @@ -532,18 +515,17 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { } // Clear the modify status for all shared memory buffers - static void cleanSharedMemory( - std::unordered_map& smem) { + static void cleanSharedMemory(std::unordered_map& smem) { smem.clear(); } // Return a set of expressions that modify shared-memory // tensors. Expressions are excluded when syncthreads are already // placed. 
- std::unordered_set isModifiedSharedMemory( - const std::unordered_map& smem, - const std::vector& tvs) const { - std::unordered_set last_writes; + std::unordered_set isModifiedSharedMemory( + const std::unordered_map& smem, + const std::vector& tvs) const { + std::unordered_set last_writes; for (auto tv : tvs) { auto it = smem.find(tv); if (it != smem.end()) { @@ -553,18 +535,17 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { return last_writes; } - ReadAfterWriteSyncs(std::vector _loop_nests) - : loop_nests_(std::move(_loop_nests)) { + ReadAfterWriteSyncs(const std::vector& _exprs) { // Fusion shared_memory values // Tracks if shared memory is modified - std::unordered_map smem; + std::unordered_map smem; // Flatten all the expressions - auto flattened_exprs = ExprFlattener::flatten(loop_nests_); + auto flattened_exprs = ExprFlattener::flatten(_exprs); - kir::Expr* prev_tv_expr = nullptr; + Expr* prev_tv_expr = nullptr; for (auto expr : flattened_exprs) { - if (!ir_utils::isTVOp(expr) || expr->isA()) { + if (!ir_utils::isTvOp(expr) || expr->isA()) { continue; } @@ -578,22 +559,20 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { cleanSharedMemory(smem); } - for (auto out : expr->outputs()) { - if (out->isA()) { - if (out->as()->memoryType() == MemoryType::Shared) { - smem[out] = expr; - } + for (auto tv : ir_utils::filterByType(expr->outputs())) { + // Double buffered tensors do not need RAW sync to be inserted + // here, except for the initial load part, which is taken care + // separately by DoubleBufferInserter. + if (tv->getMemoryType() == MemoryType::Shared && + !tv->isDoubleBuffered()) { + smem[tv] = expr; } } prev_tv_expr = expr; } - // Insert read after write syncs - const std::vector exprs = loop_nests_; - for (auto expr : exprs) { - handle(expr); - } + kir::ExprMutator::traverseAndInsert(_exprs); TORCH_INTERNAL_ASSERT( sync_after_.empty(), "Didn't place all required syncs."); @@ -601,7 +580,7 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { private: //! Keep track of expressions that must be followed by syncthreads - std::deque sync_after_; + std::deque sync_after_; //! Keep track of write expressions that must be placed before //! syncthreads. @@ -611,35 +590,27 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { //! be placed before that. last_writes_ keeps track of expressions //! modifying the smem buffer each syncthreads is used for so that //! it is not placed before those write expressions. - std::deque> last_writes_; - - //! Keep track of for loops while inserting syncthreads - std::vector for_loops_; - - //! 
Loop-nests where syncthreads are inserted - std::vector loop_nests_; + std::deque> last_writes_; public: - static std::vector insert( - const std::vector& loop_nests) { + static std::vector insert(const std::vector& loop_nests) { ReadAfterWriteSyncs inserter(loop_nests); - return inserter.loop_nests_; + return inserter.exprs_; } }; } // namespace -std::vector insertRawThreadSynchronization( - const std::vector& exprs) { +std::vector insertRawThreadSynchronization( + const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertRawThreadSynchronization"); return ReadAfterWriteSyncs::insert(exprs); } -std::vector insertWarThreadSynchronization( - const std::vector& exprs) { +std::vector insertWarThreadSynchronization( + const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertWarThreadSynchronization"); - LocalSyncInserter::insertSyncs(exprs); - return exprs; + return WarSyncInserter::insert(exprs); } } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h index 50618373448..756462f0bd7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -16,40 +16,14 @@ namespace cuda { //! //! WAR race condition occurs when the next iteration of the loop overwrites //! shared memory value before a previous operation has finished reading it. -//! -//! WAR Race Check: -//! Track all output shared memory TVs before first sync -//! Track all input shared memory TVs after last sync -//! If the intersection is non-empty, then there is a WAR race condition. -//! Recursively check each nested for-loop -//! -//! Parent-Child For-Loop Recursive Relationship -//! Notation: -//! None - Zero Syncs -//! 1+ - One or more Syncs -//! End - Sync is last op in for-loop to prevent WAR race condition -//! -//! Default: Track all shared memory inputs and outputs -//! -//! Parent - None -//! Child - None => Append All Child Outputs to Parent Initial -//! Child - 1+ => Parent first sync => Inherit Child Initial + Final -//! Child - End => Parent first sync => Keep Child Initial / Clear Parent Final -//! -//! Parent - 1+ -//! Child - None => Append All Child to Parent Last -//! Child - 1+ => Child Final to Parent Final / Discard Child Initial -//! Child - End => Clear Parent Last / Discard Child Initial -//! -//! If Child - End and Parent has zero remaining operations, then -//! Parent inherits Child End. -//! -std::vector insertWarThreadSynchronization( - const std::vector& exprs); +std::vector insertWarThreadSynchronization( + const std::vector& exprs); //! Insert syncs between writing to shared memory and then reading it. -std::vector insertRawThreadSynchronization( - const std::vector& exprs); +//! RAW pass is run before indexing, unrolling (loop duplication), memory +//! 
aliasing, and index (grid/block bcast/reduction) +std::vector insertRawThreadSynchronization( + const std::vector& exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index e4396f9a864..12c7d33e077 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -19,7 +18,7 @@ namespace jit { namespace fuser { namespace cuda { -std::vector LoopNestGenerator::loweredExprs( +std::vector LoopNestGenerator::loweredExprs( const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::LoopNestGenerator::loweredExprs"); TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); @@ -33,22 +32,20 @@ LoopNestGenerator::LoopNestGenerator(const std::vector& exprs) { namespace { -kir::ForLoop* openForHelper(kir::ForLoop* scope, kir::IterDomain* kir_id) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto extent_with_halo = gpu_lower->haloInfo().getExtent(kir_id); +kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) { + auto extent_with_halo = GpuLower::current()->haloInfo().getExtent(id); kir::ForLoop* new_scope = nullptr; if (extent_with_halo) { // When an axis is extended with halo, unrolling and vectorization // are assumed to not be used for now. TORCH_INTERNAL_ASSERT( - kir_id->parallelType() != ParallelType::Unroll && - !isParallelTypeVectorize(kir_id->parallelType())); + id->getParallelType() != ParallelType::Unroll && + !isParallelTypeVectorize(id->getParallelType())); // Use the extent that's extended by halo - new_scope = ir_builder.create( - kir_id, - kir_id->isBroadcast() ? ir_builder.zeroVal() - : ir_builder.create(c10::nullopt), + new_scope = IrBuilder::create( + id, + id->isBroadcast() ? 
GpuLower::current()->kernel()->zeroVal() + : IrBuilder::create(c10::nullopt), nullptr, extent_with_halo, nullptr, @@ -56,7 +53,7 @@ kir::ForLoop* openForHelper(kir::ForLoop* scope, kir::IterDomain* kir_id) { nullptr, false); } else { - new_scope = ir_builder.create(kir_id); + new_scope = IrBuilder::create(id); } if (scope != nullptr) { scope->body().insert(0, new_scope); @@ -66,13 +63,13 @@ kir::ForLoop* openForHelper(kir::ForLoop* scope, kir::IterDomain* kir_id) { } // namespace -void LoopNestGenerator::openFor(kir::IterDomain* kir_iter_domain) { +void LoopNestGenerator::openFor(IterDomain* id) { if (for_loops_.size() > 0) { - const auto new_scope = openForHelper(for_loops_.back(), kir_iter_domain); + const auto new_scope = openForHelper(for_loops_.back(), id); // for_loop_allocations_.insert({new_scope, 0}); for_loops_.push_back(new_scope); } else { - for_loops_.push_back(openForHelper(nullptr, kir_iter_domain)); + for_loops_.push_back(openForHelper(nullptr, id)); lowered_exprs_.insert(lowered_exprs_.begin(), for_loops_.back()); } } @@ -82,7 +79,7 @@ void LoopNestGenerator::closeFor() { for_loops_.pop_back(); } -void LoopNestGenerator::pushFront(kir::Expr* expr) { +void LoopNestGenerator::pushFront(Expr* expr) { if (for_loops_.size() == 0) { lowered_exprs_.insert(lowered_exprs_.begin(), expr); } else { @@ -91,18 +88,15 @@ void LoopNestGenerator::pushFront(kir::Expr* expr) { } void LoopNestGenerator::handle(Expr* expr) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - // Check if it's a tensor view expression we need to place in the loop nest // structure - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { // Close all the loops, scalar operations cannot be inside for loops based // on expr sorting. while (!for_loops_.empty()) { closeFor(); } - pushFront(gpu_lower->lowerExpr(expr)); + pushFront(expr); for (auto out : expr->outputs()) { TORCH_INTERNAL_ASSERT( @@ -112,10 +106,8 @@ void LoopNestGenerator::handle(Expr* expr) { " cannot lower ", out->getValType().value()); - pushFront(ir_builder.create( - gpu_lower->lowerValue(out), - MemoryType::Local, - ir_builder.create(1))); + pushFront(IrBuilder::create( + out, MemoryType::Local, GpuLower::current()->kernel()->oneVal())); } return; } @@ -130,27 +122,19 @@ void LoopNestGenerator::handle(Expr* expr) { // Figure out what the entire loop structure should look like. std::vector loop_structure = loop_structures_.at(out_tv); - std::vector kir_loop_structure; - std::transform( - loop_structure.begin(), - loop_structure.end(), - std::back_inserter(kir_loop_structure), - [&gpu_lower](IterDomain* id) { - return gpu_lower->lowerValue(id)->as(); - }); // Ordering of loop_structure is global, so simply close loops we don't need, // and open the ones we do. 
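The while/for logic that follows implements the close-then-open bookkeeping described in the comment above. A minimal standalone sketch of the same idea with plain strings in place of IterDomains (the loop names are illustrative only):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> open = {"i", "j", "k"};    // currently open loops
  std::vector<std::string> target = {"i", "j", "m"};  // desired loop structure

  // Close innermost loops that are not part of the target structure.
  while (!open.empty() &&
         std::find(target.begin(), target.end(), open.back()) == target.end()) {
    std::cout << "close " << open.back() << "\n";
    open.pop_back();
  }
  // Open target loops that are not already open, in order.
  for (const auto& id : target) {
    if (std::find(open.begin(), open.end(), id) == open.end()) {
      std::cout << "open " << id << "\n";
      open.push_back(id);
    }
  }
  return 0;
}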
while (!for_loops_.empty() && std::find( - kir_loop_structure.begin(), - kir_loop_structure.end(), - for_loops_.back()->iter_domain()) == kir_loop_structure.end()) { + loop_structure.begin(), + loop_structure.end(), + for_loops_.back()->iter_domain()) == loop_structure.end()) { closeFor(); } - for (auto loop : kir_loop_structure) { + for (auto loop : loop_structure) { auto find_it = std::find_if( for_loops_.begin(), for_loops_.end(), [loop](kir::ForLoop* fl) { return fl->iter_domain() == loop; @@ -160,7 +144,7 @@ void LoopNestGenerator::handle(Expr* expr) { } } - pushFront(gpu_lower->lowerExpr(expr)); + pushFront(expr); } namespace { diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index fbbdf079e89..9b480d7eb6f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -1,13 +1,12 @@ #pragma once -#include +#include #include #include #include #include -#include #include namespace torch { @@ -30,20 +29,20 @@ namespace cuda { //! nests to initialize reduction buffers. class TORCH_CUDA_CU_API LoopNestGenerator { public: - static std::vector loweredExprs(const std::vector& exprs); + static std::vector loweredExprs(const std::vector& exprs); private: LoopNestGenerator(const std::vector& exprs); // Open a new inner most for loop, track which TV it was constructed from // according to the computeAt chain. - void openFor(kir::IterDomain*); + void openFor(IterDomain*); // Close the inner most for loop void closeFor(); // Appends an expression to the current scope - void pushFront(kir::Expr* expr); + void pushFront(Expr* expr); void handle(Expr* expr); @@ -52,7 +51,7 @@ class TORCH_CUDA_CU_API LoopNestGenerator { private: // Lowered exprs to return - std::vector lowered_exprs_; + std::vector lowered_exprs_; // Keep all for loops conveniently to make unrolling easier, basically just a // stack of the active for_loops diff --git a/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp b/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp index f5f5c72676a..f17f91806d6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include namespace torch { @@ -12,11 +12,11 @@ namespace cuda { namespace { -class MagicZeroInserter : public kir::MutableIrVisitor { +class MagicZeroInserter : public kir::ExprMutator { public: - static std::vector insert(const std::vector& exprs) { + static std::vector insert(const std::vector& exprs) { MagicZeroInserter inserter(exprs); - return inserter.loop_nests_; + return inserter.exprs_; } private: @@ -25,94 +25,43 @@ class MagicZeroInserter : public kir::MutableIrVisitor { kir::ForLoop* fl = nullptr; }; - MagicZeroInserter(const std::vector& exprs) - : loop_nests_(exprs), ir_builder(GpuLower::current()->kernel()) { - loop_nests_.insert( - loop_nests_.begin(), ir_builder.create()); - for (auto expr : exprs) { - handle(expr); - } - insertAll(); - } - - void handle(kir::Expr* expr) { - if (auto ite = dynamic_cast(expr)) { - handle(ite); - } else if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } - } - - void handle(kir::IfThenElse* ite) { - scope_nest_.push_back(&ite->thenBody()); - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - scope_nest_.pop_back(); - scope_nest_.push_back(&ite->elseBody()); - for (auto expr : ite->elseBody().exprs()) { - handle(expr); - } - scope_nest_.pop_back(); + MagicZeroInserter(const std::vector& exprs) { + 
TORCH_INTERNAL_ASSERT(exprs.size()); + kir::ExprMutator::registerInsertBefore( + exprs.front(), IrBuilder::create(), nullptr); + kir::ExprMutator::traverseAndInsert(exprs); } - void handle(kir::ForLoop* fl) { + void handle(kir::ForLoop* fl) final { if (fl->isUnrolled()) { - kir::Scope* scope = nullptr; - if (!scope_nest_.empty()) { - scope = scope_nest_.back(); - } - insertion_list_.push_back({scope, fl}); - } else { - scope_nest_.push_back(&fl->body()); - for (auto expr : fl->body().exprs()) { - handle(expr); - } - scope_nest_.pop_back(); - } - } - - void insertAll() { - for (const auto& info : insertion_list_) { - auto fl = info.fl; - auto scope = info.scope; - if (scope == nullptr) { - // place in global scope - auto loop_it = std::find(loop_nests_.begin(), loop_nests_.end(), fl); - TORCH_INTERNAL_ASSERT(loop_it != loop_nests_.end()); - // Place after the loop - loop_it++; - loop_nests_.insert(loop_it, ir_builder.create()); + if (scope_.empty()) { + kir::ExprMutator::registerInsertAfter( + fl, IrBuilder::create()); } else { - scope->insert_after(fl, ir_builder.create()); + TORCH_INTERNAL_ASSERT( + scope_.back()->exprs().size(), "Not expecting an empty loop."); + kir::ExprMutator::registerInsertAfter( + fl, IrBuilder::create(), scope_.back()); } + } else { + kir::ExprMutator::handle(fl); } } - //! Keep track for loop structure - std::vector scope_nest_; - - // Keep a copy of the expressions provided - std::vector loop_nests_; - - kir::IrBuilder ir_builder; - std::vector insertion_list_; }; } // namespace -std::vector insertMagicZero(const std::vector& exprs) { +std::vector insertMagicZero(const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertMagicZero"); // Check if magic zero was even used, if not we don't have to define it or // update it. const auto gpu_lower = GpuLower::current(); auto kernel = gpu_lower->kernel(); - const bool has_magic_zero = std::any_of( - kernel->irNodes().begin(), - kernel->irNodes().end(), - [](const std::unique_ptr& ir_node) { - return ir_node->isA() && isMagicZero(ir_node->as()); + const bool has_magic_zero = + std::any_of(kernel->vals().begin(), kernel->vals().end(), [](Val* val) { + return isMagicZero(val); }); if (!has_magic_zero) { @@ -122,19 +71,21 @@ std::vector insertMagicZero(const std::vector& exprs) { return MagicZeroInserter::insert(exprs); } -bool isMagicZero(kir::Val* val) { - auto ns = dynamic_cast(val); - if (ns == nullptr) { +bool isMagicZero(const Val* val) { + if (!val->isA()) { return false; } + auto ns = val->as(); return ns->dtype() == DataType::Int && ns->name() == std::string(kMagicZeroName); } -bool isProtectedWithMagicZero(kir::Val* val) { - auto def = dynamic_cast(val->definition()); - return def && def->operation() == BinaryOpType::Add && - isMagicZero(def->rhs()); +bool isProtectedWithMagicZero(const Val* val) { + if (val->definition() == nullptr || !val->definition()->isA()) { + return false; + } + auto bop = val->definition()->as(); + return bop->getBinaryOpType() == BinaryOpType::Add && isMagicZero(bop->rhs()); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_magic_zero.h b/torch/csrc/jit/codegen/cuda/lower_magic_zero.h index 03a37a46813..942a3302801 100644 --- a/torch/csrc/jit/codegen/cuda/lower_magic_zero.h +++ b/torch/csrc/jit/codegen/cuda/lower_magic_zero.h @@ -14,15 +14,15 @@ namespace cuda { //! zero update after every (outer most) loop nest with a compile time extent. //! //! This will make sure nvrtc does not aggressively save predicate and indices. 
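To make the isMagicZero / isProtectedWithMagicZero checks above concrete, here is a minimal standalone sketch with hypothetical Val/BinaryOp structs rather than the real IR classes (the actual name string comes from the kMagicZeroName constant in the real sources; the value below is a placeholder):

#include <iostream>
#include <string>

// Hypothetical stand-ins for the IR node types; not the real classes.
struct Val;

struct BinaryOp {
  std::string op_type;  // e.g. "add"
  const Val* lhs = nullptr;
  const Val* rhs = nullptr;
};

struct Val {
  std::string name;
  const BinaryOp* definition = nullptr;
};

// Placeholder for the real kMagicZeroName constant.
const std::string kMagicZeroName = "magic_zero";

bool isMagicZero(const Val* val) {
  return val != nullptr && val->name == kMagicZeroName;
}

// A value is "protected" when it is defined as x + magic_zero.
bool isProtectedWithMagicZero(const Val* val) {
  const BinaryOp* def = val->definition;
  return def != nullptr && def->op_type == "add" && isMagicZero(def->rhs);
}

int main() {
  Val magic_zero{kMagicZeroName, nullptr};
  Val index{"i0", nullptr};
  BinaryOp add{"add", &index, &magic_zero};
  Val protected_index{"i0 + magic_zero", &add};
  std::cout << isProtectedWithMagicZero(&protected_index) << "\n";  // 1
  std::cout << isProtectedWithMagicZero(&index) << "\n";            // 0
  return 0;
}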
-std::vector insertMagicZero(const std::vector& exprs); +std::vector insertMagicZero(const std::vector& exprs); //! Check if val is a reference to the magic zero variable -bool isMagicZero(kir::Val* val); +bool isMagicZero(const Val* val); //! Check if val is protected with magic zero. //! //! Specifically, this returns true if val is defined as "x + magic_zero". -bool isProtectedWithMagicZero(kir::Val* val); +bool isProtectedWithMagicZero(const Val* val); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp index b94c12c27c8..66b405ac8e2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp @@ -5,8 +5,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -18,85 +17,64 @@ namespace cuda { namespace { -class MisalignedVectorizationModifier { +class MisalignedVectorizationModifier : public kir::ExprMutator { public: - void process(const std::vector& exprs) { - FUSER_PERF_SCOPE( - "GpuLower::Lower::MisalignedVectorizationModifier::process"); - // Run through loop nests - // Find for-loops with misaligned vectorization domains - for (auto* expr : exprs) { - handle(expr); - } - } + MisalignedVectorizationModifier() = delete; - const std::unordered_map& replacementMap() const { - return expr_replacement_map_; + static std::vector processMisalignedVectorization( + const std::vector& exprs) { + FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization"); + MisalignedVectorizationModifier mvm(exprs); + return mvm.exprs_; } private: - void handle(kir::Expr* expr) { - if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - } + MisalignedVectorizationModifier(const std::vector& exprs) { + FUSER_PERF_SCOPE("GpuLower::Lower::MisalignedVectorizationModifier"); + // Run through loop nests + // Find for-loops with misaligned vectorization domains + kir::ExprMutator::traverseAndInsert(exprs); } - void handle(kir::ForLoop* fl) { - for_loops_structure_.push_back(fl); - - // Make copy of exprs because we replace them inplace in fl - const auto exprs_copy = fl->body().exprs(); - + void handle(kir::ForLoop* fl) final { + kir::Scope* scope = scope_.empty() ? 
nullptr : scope_.back(); if (containsAnyDirectChildMisalignedVectorize(fl)) { - auto new_fl = handleMisalignedVectorize(for_loops_structure_, fl); - expr_replacement_map_.insert({fl, new_fl}); - } else { - for (auto expr : exprs_copy) { - handle(expr); - } - } + for_loops_.push_back(fl); + auto new_fl = handleMisalignedVectorize(for_loops_, fl); + for_loops_.pop_back(); - for_loops_structure_.pop_back(); - } - - void handle(kir::IfThenElse* ite) { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); + kir::ExprMutator::registerReplace(fl, new_fl, scope); + } else { + kir::ExprMutator::handle(fl); } } struct ReferenceTensors { // Input TensorView to Vectorize Set operation - kir::TensorView* in_tv = nullptr; + TensorView* in_tv = nullptr; // Output TensorView to Vectorize Set operation - kir::TensorView* out_tv = nullptr; + TensorView* out_tv = nullptr; // TensorView in global memory - kir::TensorView* global_tv = nullptr; + TensorView* global_tv = nullptr; // TensorView with vectorize IterDomain and not in global memory - kir::TensorView* vec_tv = nullptr; + TensorView* vec_tv = nullptr; }; - ReferenceTensors getReferenceTensors(kir::Expr* vectorized_expr) { + ReferenceTensors getReferenceTensors(Expr* vectorized_expr) { TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr); TORCH_INTERNAL_ASSERT( - vectorized_expr->outputs().front()->isA()); - TORCH_INTERNAL_ASSERT( - vectorized_expr->inputs().front()->isA()); + vectorized_expr->outputs().front()->isA()); + TORCH_INTERNAL_ASSERT(vectorized_expr->inputs().front()->isA()); - auto in_tv = vectorized_expr->inputs().front()->as(); - auto out_tv = vectorized_expr->outputs().front()->as(); + auto in_tv = vectorized_expr->inputs().front()->as(); + auto out_tv = vectorized_expr->outputs().front()->as(); const bool global_vectorize_write_op = - (out_tv->memoryType() == MemoryType::Global && - in_tv->memoryType() == MemoryType::Local); + (out_tv->getMemoryType() == MemoryType::Global && + in_tv->getMemoryType() == MemoryType::Local); const bool global_vectorize_read_op = - (out_tv->memoryType() == MemoryType::Local && - in_tv->memoryType() == MemoryType::Global); + (out_tv->getMemoryType() == MemoryType::Local && + in_tv->getMemoryType() == MemoryType::Global); TORCH_INTERNAL_ASSERT( global_vectorize_write_op || global_vectorize_read_op, "Unsupported vectorize memory configuration detected."); @@ -104,25 +82,26 @@ class MisalignedVectorizationModifier { // TensorView on global memory. This is the tensor that may have // a non-aligned base address. auto global_tv = - (out_tv->memoryType() == MemoryType::Global) ? out_tv : in_tv; + (out_tv->getMemoryType() == MemoryType::Global) ? out_tv : in_tv; // TensorView with the misaligned vec iterDomain. It is the consumer // of vectorized load or the producer of vectorized store. It is // assumed that when the output TV is not on global memory, this // expression is a vectorized load, so the output TV is vec_tv. - auto vec_tv = (out_tv->memoryType() != MemoryType::Global) ? out_tv : in_tv; + auto vec_tv = + (out_tv->getMemoryType() != MemoryType::Global) ? 
out_tv : in_tv; return {in_tv, out_tv, global_tv, vec_tv}; } struct VectorizeData { - kir::Val* vector_size = nullptr; - kir::Val* shift = nullptr; - kir::Val* extent = nullptr; - kir::Val* remainder = nullptr; - kir::Val* extent_minus_remainder = nullptr; - kir::Val* last_root_domain_index = nullptr; - kir::Val* last_root_domain_index_shift = nullptr; + Val* vector_size = nullptr; + Val* shift = nullptr; + Val* extent = nullptr; + Val* remainder = nullptr; + Val* extent_minus_remainder = nullptr; + Val* last_root_domain_index = nullptr; + Val* last_root_domain_index_shift = nullptr; }; // Create constants for handling misaligned addresses @@ -130,48 +109,43 @@ class MisalignedVectorizationModifier { const std::vector& for_loop_structure, const ReferenceTensors& tensors, kir::IfThenElse* parent_scope_ite) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - // Generate vectorize index - auto indices = (tensors.out_tv->memoryType() == MemoryType::Global) - ? Index::getConsumerStridedIndices( - tensors.out_tv->fuserTv(), for_loop_structure) + auto indices = (tensors.out_tv->getMemoryType() == MemoryType::Global) + ? Index::getConsumerStridedIndices(tensors.out_tv, for_loop_structure) : Index::getProducerStridedIndices( - tensors.in_tv->fuserTv(), - tensors.out_tv->fuserTv(), - for_loop_structure); + tensors.in_tv, tensors.out_tv, for_loop_structure); // >>>>>>>>>>>>> // Number of elements in vectorize access auto vector_size = - tensors.vec_tv->domain()->domain().back()->extent()->as(); + tensors.vec_tv->domain()->domain().back()->extent()->as(); // Size of memory type for the elements - kir::Int* data_size_in_bytes = - ir_builder.create(dataTypeSize(tensors.vec_tv->dtype())); + Int* data_size_in_bytes = + IrBuilder::create(dataTypeSize(tensors.vec_tv->dtype())); // The number of bytes in the vectorize access auto vector_size_in_bytes = - ir_builder.mulExpr(vector_size, data_size_in_bytes); + IrBuilder::mulExpr(vector_size, data_size_in_bytes); - auto index = ir_builder.create( - tensors.global_tv->fuserTv(), indices); + auto index = + IrBuilder::create(tensors.global_tv, indices); auto address = createNamedScalarFromValue( parent_scope_ite->thenBody(), index, "address", true); // offset_size = (address % vector_size_bytes) / data_type_size_bytes // shift_init = vector_size - offset_size - auto a = ir_builder.modExpr(address, vector_size_in_bytes); - auto b = ir_builder.divExpr(a, data_size_in_bytes); - auto c = ir_builder.subExpr(vector_size, b); + auto a = IrBuilder::modExpr(address, vector_size_in_bytes); + auto b = IrBuilder::divExpr(a, data_size_in_bytes); + auto c = IrBuilder::subExpr(vector_size, b); auto shift_init = createNamedScalarFromValue( parent_scope_ite->thenBody(), c, "shift_val"); // shift = (shift_init == vector_size) ? 
0 : shift_init // The number of elements until the first aligned address - auto shift_pred = ir_builder.eqExpr(shift_init, vector_size); - auto shift_val = - ir_builder.whereExpr(shift_pred, ir_builder.zeroVal(), shift_init); + auto shift_pred = IrBuilder::eqExpr(shift_init, vector_size); + auto shift_val = IrBuilder::whereExpr( + shift_pred, GpuLower::current()->kernel()->zeroVal(), shift_init); // >>>>>>>>>>>>> auto shift = createNamedScalarFromValue( @@ -183,13 +157,13 @@ class MisalignedVectorizationModifier { // remainder = (extent - shift) % vector_size // The number of elements remaining not accessed by vectorized operations - auto remaining_extent = ir_builder.subExpr(extent, shift); - auto remainder_val = ir_builder.modExpr(remaining_extent, vector_size); + auto remaining_extent = IrBuilder::subExpr(extent, shift); + auto remainder_val = IrBuilder::modExpr(remaining_extent, vector_size); auto remainder = createNamedScalarFromValue( parent_scope_ite->thenBody(), remainder_val, "remainder"); // (extent - remainder) is the upper-bound for the vectorize section - auto extent_remainder_val = ir_builder.subExpr(extent, remainder); + auto extent_remainder_val = IrBuilder::subExpr(extent, remainder); // >>>>>>>>>>>>> auto extent_minus_remainder = createNamedScalarFromValue( @@ -203,7 +177,7 @@ class MisalignedVectorizationModifier { // >>>>>>>>>>>>> auto last_root_domain_index_shift = - ir_builder.addExpr(last_root_domain_index, shift); + IrBuilder::addExpr(last_root_domain_index, shift); return { vector_size, @@ -220,20 +194,18 @@ class MisalignedVectorizationModifier { kir::IfThenElse* createVectorizeSection( const std::vector& child_loops, const VectorizeData& params) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto vectorized_child_loops = cloneForLoops( child_loops, params.vector_size, nullptr, true, params.shift); // Vectorize Range: [shift - (extent-remainder)) // (last_root_domain_index + shift) < (extent - remainder) - kir::Val* vectorize_cond = ir_builder.ltExpr( + Val* vectorize_cond = IrBuilder::ltExpr( params.last_root_domain_index_shift, params.extent_minus_remainder); kir::Predicate* vectorize_pred = - ir_builder.create(vectorize_cond->as()); + IrBuilder::create(vectorize_cond->as()); kir::IfThenElse* vectorize_ite = - ir_builder.create(vectorize_pred); + IrBuilder::create(vectorize_pred); for (auto cloned_loop : vectorized_child_loops) { vectorize_ite->thenBody().push_back(cloned_loop); @@ -247,20 +219,19 @@ class MisalignedVectorizationModifier { kir::IfThenElse* createInitialSection( const std::vector& child_loops, const VectorizeData& params) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto pre_child_loops = cloneForLoops( child_loops, params.vector_size, params.shift, false, nullptr); // Initial Range: [0 - shift) // last_root_domain_index == 0 - kir::Val* initial_cond = - ir_builder.eqExpr(params.last_root_domain_index, ir_builder.zeroVal()); + Val* initial_cond = IrBuilder::eqExpr( + params.last_root_domain_index, + GpuLower::current()->kernel()->zeroVal()); kir::Predicate* initial_pred = - ir_builder.create(initial_cond->as()); + IrBuilder::create(initial_cond->as()); kir::IfThenElse* initial_ite = - ir_builder.create(initial_pred); + IrBuilder::create(initial_pred); for (auto cloned_loop : pre_child_loops) { initial_ite->thenBody().push_back(cloned_loop); @@ -274,23 +245,21 @@ class MisalignedVectorizationModifier { kir::IfThenElse* createRemainderSection( const std::vector& child_loops, const VectorizeData& params) { - 
kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto post_child_loops = cloneForLoops( child_loops, params.vector_size, params.remainder, false, params.shift); // Remainder Range: [(extent-remainder) - extent) // (extent - remainder) <= last_root_domain_index + shift < extent - kir::Val* lower_bound = ir_builder.geExpr( + Val* lower_bound = IrBuilder::geExpr( params.last_root_domain_index_shift, params.extent_minus_remainder); - kir::Val* upper_bound = - ir_builder.ltExpr(params.last_root_domain_index_shift, params.extent); - kir::Val* remainder_cond = ir_builder.andExpr(lower_bound, upper_bound); + Val* upper_bound = + IrBuilder::ltExpr(params.last_root_domain_index_shift, params.extent); + Val* remainder_cond = IrBuilder::andExpr(lower_bound, upper_bound); kir::Predicate* remainder_pred = - ir_builder.create(remainder_cond->as()); + IrBuilder::create(remainder_cond->as()); kir::IfThenElse* remainder_ite = - ir_builder.create(remainder_pred); + IrBuilder::create(remainder_pred); for (auto cloned_loop : post_child_loops) { remainder_ite->thenBody().push_back(cloned_loop); @@ -302,8 +271,6 @@ class MisalignedVectorizationModifier { kir::ForLoop* handleMisalignedVectorize( std::vector for_loop_structure, const kir::ForLoop* parent_for_loop) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto child_loops = findChildForLoops(parent_for_loop); // Assumption: All vectorize operations have the same shift @@ -315,17 +282,19 @@ class MisalignedVectorizationModifier { // The parent_for_loop contains allocate, read, compute, write operations const auto new_parent_for_loop = - ir_builder.create(parent_for_loop); + IrBuilder::create(parent_for_loop); // Transfer all expressions except for-loops to new parent for-loop // All expressions are placed at the beginning of the new for-loop - moveExprsExceptForLoops(parent_for_loop, new_parent_for_loop); + copyExprsExceptForLoops(parent_for_loop, new_parent_for_loop); // Get the predicate for all but the last root domain - auto pred_except_last_root_domain = ir_builder.create( - PredicateType::Misaligned, vectorized_expr, ir_builder.trueVal()); + auto pred_except_last_root_domain = IrBuilder::create( + PredicateType::Misaligned, + vectorized_expr, + GpuLower::current()->kernel()->trueVal()); kir::IfThenElse* pred_ite = - ir_builder.create(pred_except_last_root_domain); + IrBuilder::create(pred_except_last_root_domain); new_parent_for_loop->body().push_back(pred_ite); auto constants = createVectorizeConstants( @@ -351,17 +320,17 @@ class MisalignedVectorizationModifier { // Determine that the expression is UnaryOpType::Set AND // the output TensorView domain is vectorized - bool isVectorizeSetOp(kir::ForLoop* fl, kir::Expr* expr) { - if (fl->iter_domain()->parallelType() != + bool isVectorizeSetOp(kir::ForLoop* fl, Expr* expr) { + if (fl->iter_domain()->getParallelType() != ParallelType::MisalignedVectorize) { return false; } - if (expr->isA()) { - auto unaryOp = expr->as(); - if (unaryOp->out()->isA()) { - auto out_tv = unaryOp->out()->as(); - return unaryOp->operation() == UnaryOpType::Set && + if (expr->isA()) { + auto unaryOp = expr->as(); + if (unaryOp->out()->isA()) { + auto out_tv = unaryOp->out()->as(); + return unaryOp->getUnaryOpType() == UnaryOpType::Set && out_tv->domain()->hasVectorize(); } } @@ -374,15 +343,14 @@ class MisalignedVectorizationModifier { // vectorize flag - Do not generate for loop header // shift value - Add shift to global indices generated within for loop std::vector cloneForLoops( - const 
std::vector& for_loops, - kir::Val* loop_stop, - kir::Val* pred_stop, + const std::vector& for_loops_, + Val* loop_stop, + Val* pred_stop, bool vectorize, - kir::Val* vectorize_shift) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); + Val* vectorize_shift) { std::vector cloned_for_loops; - for (auto fl : for_loops) { + for (auto fl : for_loops_) { auto first_expr = fl->body().exprs().front(); bool has_vectorize_op = isVectorizeSetOp(fl, first_expr); @@ -391,12 +359,12 @@ class MisalignedVectorizationModifier { TORCH_INTERNAL_ASSERT( !has_vectorize_op || fl->body().exprs().size() == 1); - const auto new_loop = ir_builder.create( + const auto new_loop = IrBuilder::create( fl->iter_domain(), fl->index(), - ir_builder.zeroVal(), + GpuLower::current()->kernel()->zeroVal(), loop_stop, - ir_builder.oneVal(), + GpuLower::current()->kernel()->oneVal(), vectorize && has_vectorize_op, vectorize_shift, fl->isUnrollRequired()); @@ -406,9 +374,9 @@ class MisalignedVectorizationModifier { // Predicate the loop body if pred_stop is not null. This is to // make sure the loop itself is completely unrollable. if (pred_stop != nullptr) { - auto body_pred = ir_builder.create( - ir_builder.ltExpr(new_loop->index(), pred_stop)->as()); - auto body_ite = ir_builder.create(body_pred); + auto body_pred = IrBuilder::create( + IrBuilder::ltExpr(new_loop->index(), pred_stop)->as()); + auto body_ite = IrBuilder::create(body_pred); body->push_back(body_ite); body = &body_ite->thenBody(); } @@ -423,7 +391,7 @@ class MisalignedVectorizationModifier { } // Add all expressions except for loops to new parent for loop - void moveExprsExceptForLoops( + void copyExprsExceptForLoops( const kir::ForLoop* for_loop, kir::ForLoop* new_loop) { std::vector loops; @@ -448,10 +416,10 @@ class MisalignedVectorizationModifier { // Find the first vectorize set - either read or write // Add child For-Loop to for_loop_structure // Enable vectorize flag in child For-Loop - kir::Expr* findFirstVectorizedSetOp( + Expr* findFirstVectorizedSetOp( std::vector& for_loop_structure, - const std::vector& for_loops) { - for (auto fl : for_loops) { + const std::vector& for_loops_) { + for (auto fl : for_loops_) { auto first_expr = fl->body().exprs().front(); bool has_vectorize_op = isVectorizeSetOp(fl, first_expr); if (has_vectorize_op) { @@ -463,38 +431,31 @@ class MisalignedVectorizationModifier { } // Get full extent for the inner-most, merged root domain - kir::Val* getVectorizeExtent( - kir::TensorView* producer_tv, - kir::TensorView* consumer_tv) { + Val* getVectorizeExtent(TensorView* producer_tv, TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto consumer_fuser_tv = consumer_tv->fuserTv(); - auto producer_fuser_tv = producer_tv->fuserTv(); - auto p2c = - PairwiseRootDomainMap(producer_fuser_tv, consumer_fuser_tv) - .mapProducerToConsumer( - producer_fuser_tv->domain(), consumer_fuser_tv->domain()); + auto p2c = PairwiseRootDomainMap(producer_tv, consumer_tv) + .mapProducerToConsumer( + producer_tv->domain(), consumer_tv->domain()); auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo( - {consumer_fuser_tv->domain()->domain().begin() + - consumer_fuser_tv->getComputeAtPosition(), - consumer_fuser_tv->domain()->domain().end()}); + {consumer_tv->domain()->domain().begin() + + consumer_tv->getComputeAtPosition(), + consumer_tv->domain()->domain().end()}); auto producer_root_right_of_ca_domains = IterVisitor::getInputsTo( - 
{producer_fuser_tv->domain()->domain().begin() + - producer_fuser_tv->getComputeAtPosition(), - producer_fuser_tv->domain()->domain().end()}); + {producer_tv->domain()->domain().begin() + + producer_tv->getComputeAtPosition(), + producer_tv->domain()->domain().end()}); - const auto& consumer_contig = consumer_fuser_tv->domain()->contiguity(); - const auto& producer_contig = producer_fuser_tv->domain()->contiguity(); + const auto& consumer_contig = consumer_tv->domain()->contiguity(); + const auto& producer_contig = producer_tv->domain()->contiguity(); - auto producer_root_domain = producer_fuser_tv->getMaybeRFactorDomain(); + auto producer_root_domain = producer_tv->getMaybeRFactorDomain(); // Calculate extent of merged root domains - kir::Val* extent = nullptr; + Val* extent = nullptr; auto consumer_root_idx = - int(consumer_fuser_tv->getMaybeRFactorDomain().size()) - 1; + int(consumer_tv->getMaybeRFactorDomain().size()) - 1; for (int i = int(producer_root_domain.size()) - 1; i >= 0; --i) { auto producer_root_id = producer_root_domain.at(i); @@ -533,11 +494,10 @@ class MisalignedVectorizationModifier { // We now know it's safe to extend the vectorization domain to these // axes. It shouldn't matter whether producer or consumer is used. - auto consumer_extent = gpu_lower->lowerValue(consumer_root_id->extent()); if (extent == nullptr) { - extent = consumer_extent; + extent = consumer_root_id->extent(); } else { - extent = ir_builder.mulExpr(extent, consumer_extent); + extent = IrBuilder::mulExpr(extent, consumer_root_id->extent()); } // If it's not contiguous, extending the vectorization domain @@ -554,57 +514,37 @@ class MisalignedVectorizationModifier { return extent; } - kir::Val* createNamedScalarFromValue( + Val* createNamedScalarFromValue( kir::Scope& body, - kir::Val* val, + Val* val, const std::string& name, bool address = false) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto namedScalar = (address) ? ir_builder.addressExprNamedScalar(name, val) - : ir_builder.setExprNamedScalar(name, val); + auto namedScalar = (address) ? 
IrBuilder::addressExprNamedScalar(name, val) + : IrBuilder::setExprNamedScalar(name, val); TORCH_INTERNAL_ASSERT(namedScalar->definition() != nullptr); - auto alloc = ir_builder.create( - namedScalar, MemoryType::Local, ir_builder.oneVal()); + auto alloc = IrBuilder::create( + namedScalar, + MemoryType::Local, + GpuLower::current()->kernel()->oneVal()); body.push_back(alloc); body.push_back(namedScalar->definition()); return namedScalar; } - - private: - // We will track which loops in the incoming IR will be replaced and by what - std::unordered_map expr_replacement_map_; - - // A depth-first ordering of nested for loops - // It is used for indexing and predicate generation - std::vector for_loops_structure_; }; } // namespace -std::vector processMisalignedVectorization( - Fusion* fusion, - const std::vector& exprs) { - FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization"); - - MisalignedVectorizationModifier mvm; - mvm.process(exprs); - - std::vector mutated_exprs; - mutated_exprs.reserve(exprs.size()); - for (auto expr : exprs) { - mutated_exprs.push_back( - ir_utils::applyReplacements(mvm.replacementMap(), expr)); - } - - return mutated_exprs; +std::vector processMisalignedVectorization( + const std::vector& exprs) { + return MisalignedVectorizationModifier::processMisalignedVectorization(exprs); } bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl) { for (auto expr : fl->body().exprs()) { if (expr->isA()) { auto child_fl = expr->as(); - if (child_fl->iter_domain()->parallelType() == + if (child_fl->iter_domain()->getParallelType() == ParallelType::MisalignedVectorize) { return true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h index 588d3787752..bd7ae19d93a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h @@ -1,5 +1,5 @@ #pragma once -#include +#include #include @@ -106,9 +106,8 @@ namespace cuda { //! } //! } //! 
-std::vector processMisalignedVectorization( - Fusion* fusion, - const std::vector& exprs); +std::vector processMisalignedVectorization( + const std::vector& exprs); bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 838d5d85d9e..cd34c56b510 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -7,8 +7,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -23,27 +22,26 @@ namespace cuda { namespace { -class ConditionalFromPredicateModifier { +class ConditionalFromPredicateModifier : public kir::IrVisitor { public: - ConditionalFromPredicateModifier(const std::vector& exprs) { + ConditionalFromPredicateModifier() = delete; + + static std::vector fillPredicates(const std::vector& exprs) { + ConditionalFromPredicateModifier cfpm(exprs); + return cfpm.exprs_; + } + + private: + ConditionalFromPredicateModifier(const std::vector& exprs) { FUSER_PERF_SCOPE( "GpuLower::Lower::ConditionalFromPredicateModifier::process"); - for (auto* expr : exprs) { - handle(expr); - } + kir::IrVisitor::handle(exprs); } - const std::unordered_map& replacementMap() const { - return expr_replacement_map_; - } + using kir::IrVisitor::handle; - private: - void handle(kir::Expr* expr) { - if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - } else if (expr != nullptr && expr->predicate() != nullptr) { + void handle(Expr* expr) final { + if (expr != nullptr && expr->predicate() != nullptr) { // Replace expr predicate with bool conditional auto conditional = generateConditional(expr->predicate()); TORCH_INTERNAL_ASSERT(conditional != nullptr); @@ -51,9 +49,11 @@ class ConditionalFromPredicateModifier { TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr); setWritePredicate(expr, conditional); } + + kir::IrVisitor::handle(expr); } - void setWritePredicate(kir::Expr* expr, kir::Bool* read_cond) { + void setWritePredicate(Expr* expr, Bool* read_cond) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); if (write_cond) { @@ -66,46 +66,25 @@ class ConditionalFromPredicateModifier { } } - void handle(kir::ForLoop* fl) { - for_loops_structure_.push_back(fl); - - const auto exprs_copy = fl->body().exprs(); - for (auto expr : exprs_copy) { - handle(expr); - } - - for_loops_structure_.pop_back(); - } - - void handle(kir::IfThenElse* ite) { + void handle(kir::IfThenElse* ite) final { TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr); // If ite already has Bool conditional, handle internal expressions // Otherwise, generate conditional and update predicate - if (ite->predicate()->hasValue()) { - const auto then_exprs_copy = ite->thenBody().exprs(); - for (auto expr : then_exprs_copy) { - handle(expr); - } - - const auto else_exprs_copy = ite->elseBody().exprs(); - for (auto expr : else_exprs_copy) { - handle(expr); - } - } else { + if (!ite->predicate()->hasValue()) { auto conditional = generateConditional(ite->predicate()); TORCH_INTERNAL_ASSERT(conditional != nullptr); - TORCH_INTERNAL_ASSERT(conditional->isA()); + TORCH_INTERNAL_ASSERT(conditional->isA()); // Update bool conditional in-place ite->predicate()->setValue(conditional); - handle(ite); TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr); } + kir::IrVisitor::handle(ite); } // Generate 
conditional according to PredicateType - kir::Bool* generateConditional(kir::Predicate* pred) { + Bool* generateConditional(kir::Predicate* pred) { switch (pred->predicate_type()) { case PredicateType::Inline: case PredicateType::ReductionWrite: @@ -114,15 +93,16 @@ class ConditionalFromPredicateModifier { case PredicateType::Padding: { return PredicateCompute::getInlinePredicate( pred->expr(), - for_loops_structure_, + for_loops_, pred->thread_pred(), pred->predicate_type()); } case PredicateType::Vectorize: { std::vector outer_loops; kir::ForLoop* vectorized_loop = nullptr; - for (auto loop : for_loops_structure_) { - if (loop->iter_domain()->parallelType() == ParallelType::Vectorize) { + for (auto loop : for_loops_) { + if (loop->iter_domain()->getParallelType() == + ParallelType::Vectorize) { vectorized_loop = loop; break; } else { @@ -134,8 +114,7 @@ class ConditionalFromPredicateModifier { return UnswitchPredicate::get(outer_loops, vectorized_loop); } case PredicateType::Unswitch: { - return UnswitchPredicate::get( - for_loops_structure_, pred->unrolled_loop()); + return UnswitchPredicate::get(for_loops_, pred->unrolled_loop()); } case PredicateType::Manual: { return pred->value(); @@ -145,33 +124,13 @@ class ConditionalFromPredicateModifier { } return nullptr; } - - private: - // We will track which loops in the incoming IR will be replaced and by what - std::unordered_map expr_replacement_map_; - - // A depth-first ordering of nested for loops - // It is used for indexing and predicate generation - std::vector for_loops_structure_; }; } // namespace -std::vector generateConditionalFromPredicate( - Fusion* fusion, - const std::vector& exprs) { - FUSER_PERF_SCOPE("GpuLower::Lower::generateConditionalFromPredicate"); - - ConditionalFromPredicateModifier p2cm(exprs); - - std::vector mutated_exprs; - mutated_exprs.reserve(exprs.size()); - for (auto expr : exprs) { - mutated_exprs.push_back( - ir_utils::applyReplacements(p2cm.replacementMap(), expr)); - } - - return mutated_exprs; +std::vector generateConditionalFromPredicate( + const std::vector& exprs) { + return ConditionalFromPredicateModifier::fillPredicates(exprs); } namespace { @@ -225,17 +184,14 @@ class PredicateAnalyzer : public OptOutDispatch { return needs_predicate_; } - using OptOutDispatch::handle; - void handle(IterDomain* consumer_id) override { // The traversal should have ended if needs_predicate_ was true TORCH_INTERNAL_ASSERT(!needs_predicate_); // If consumer_id is not going to be materialized as a loop (e.g., // broadcast), no need to predicate - const auto gpu_lower = GpuLower::current(); if (consumer_id->isBroadcast() || - gpu_lower->trivialReductionInfo().isDerived(consumer_id)) { + GpuLower::current()->trivialReductionInfo().isDerived(consumer_id)) { return; } @@ -250,7 +206,7 @@ class PredicateAnalyzer : public OptOutDispatch { return; } - handle(consumer_id->definition()); + OptOutDispatch::handle(consumer_id->definition()); } // If it splits the input axis evenly, proceeds to check the input @@ -291,7 +247,7 @@ class PredicateAnalyzer : public OptOutDispatch { } // namespace bool PredicateElimination::needsPredicate(Expr* expr) const { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } @@ -394,7 +350,7 @@ bool PredicateElimination::needsPredicate(Expr* expr) const { } void PredicateElimination::handle(Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return; } @@ -491,7 +447,7 @@ bool PredicateElimination::setReductionInitValue( bool 
PredicateElimination::canOmitPredicate(const Expr* expr) const { TORCH_INTERNAL_ASSERT(expr != nullptr); - const auto out_tv = ir_utils::getTVOutput(expr); + const auto out_tv = ir_utils::getTvOutput(expr); TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression"); // No need to predicate local tensors to which a scalar is assigned if (out_tv->getMemoryType() == MemoryType::Local) { @@ -508,38 +464,17 @@ bool PredicateElimination::canOmitPredicate(const Expr* expr) const { return false; } -bool PredicateElimination::canOmitPredicate(const kir::Expr* kir_expr) const { - TORCH_INTERNAL_ASSERT(kir_expr != nullptr); - const auto out_tv = ir_utils::getTVOutput(kir_expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression"); - // No need to predicate local tensors to which a scalar is assigned - if (out_tv->memoryType() == MemoryType::Local) { - if (auto uop = dynamic_cast(kir_expr)) { - if (uop->operation() == UnaryOpType::Set && uop->in()->isScalar()) { - return true; - } - } - } - const auto fuser_tv = out_tv->fuserTv(); - if (fuser_tv == nullptr) { - return false; - } - return canOmitPredicate(fuser_tv->definition()); -} - -kir::Val* PredicateElimination::getInitValue(TensorView* tv) const { +Val* PredicateElimination::getInitValue(TensorView* tv) const { auto it = init_value_map_.find(tv); if (it == init_value_map_.end()) { return nullptr; } - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); auto init_val = it->second; if (init_val == nullptr) { // No reduction restriction. Just use zero - return ir_builder.zeroVal(); + return GpuLower::current()->kernel()->zeroVal(); } else { - return gpu_lower->lowerValue(init_val); + return init_val; } } diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.h b/torch/csrc/jit/codegen/cuda/lower_predicate.h index 393d0fa5c18..c0a1f702f7b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.h +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.h @@ -1,5 +1,5 @@ #pragma once -#include +#include #include #include @@ -13,9 +13,8 @@ namespace cuda { //! Update predicates with valid bool conditionals //! -std::vector generateConditionalFromPredicate( - Fusion* fusion, - const std::vector& exprs); +std::vector generateConditionalFromPredicate( + const std::vector& exprs); class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { public: @@ -26,13 +25,8 @@ class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { //! \param expr Tensor expression bool canOmitPredicate(const Expr* expr) const; - //! True if expr does not need a predicate - //! - //! \param expr KIR tensor expr - bool canOmitPredicate(const kir::Expr* expr) const; - //! Value to initialize out-of-bound regions - kir::Val* getInitValue(TensorView* tv) const; + Val* getInitValue(TensorView* tv) const; //! Dump to string for debugging std::string toString() const; @@ -40,7 +34,7 @@ class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { private: using IterVisitor::handle; - void handle(Expr* expr) override; + void handle(Expr* expr) final; //! 
Set a value to initialize out-of-bound regions bool setDefaultInitValue(TensorView* tv); diff --git a/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp new file mode 100644 index 00000000000..582b6d91d06 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp @@ -0,0 +1,288 @@ +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { +// Going to generate a map of tensor view root domain extents to reduce the +// number used during lowering. For example if we have: +// +// T2[i0, i1] = T1[i0, i1] + T2[i2, i3] +// +// We know it would be safe to use: +// +// T2[i0, i1] = T1[i0, i1] + T2[i0, i1] +// +// And that way we don't generate T2.size[0] and T2.size[1], instead we will +// reuse T1.size[0] and T1.size[1]. +// This is important when doing CSE as T2 and T1 would otherwise look like +// they're using different values, even though we know they're the same. +// +// There's some duplicate logic here that's in the computeAt map, but it's not +// concise enough there to pull out. May want to consider making this mapping +// its own class, especially as it may be useful during scheduling. +std::unordered_map getSimplificationMap(Fusion* fusion) { + std::list> disjoint_root_sets; + std::unordered_map*> + id_to_disjoint_root_set; + + auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set]( + IterDomain* id0, IterDomain* id1) { + if (id0->isBroadcast() || id1->isBroadcast()) { + return; + } + + auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0); + auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1); + bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end(); + bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end(); + + if (set_0_found && set_1_found) { + if (disjoint_set_0_it->second == disjoint_set_1_it->second) { + return; + } + // merge second disjoint set into first + auto* set_0 = disjoint_set_0_it->second; + auto* set_1 = disjoint_set_1_it->second; + for (auto id : *set_1) { + set_0->emplace(id); + id_to_disjoint_root_set[id] = set_0; + } + // remove second set from disjoint_root_sets + disjoint_root_sets.erase(std::find( + disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1)); + } else if (set_0_found || set_1_found) { + auto existing_set = + set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second; + auto to_add_id = set_0_found ?
id1 : id0; + existing_set->emplace(to_add_id); + id_to_disjoint_root_set[to_add_id] = existing_set; + // add entry into existing set + } else { + // create new set entry + disjoint_root_sets.emplace_back(std::unordered_set()); + auto* new_set = &disjoint_root_sets.back(); + new_set->emplace(id0); + new_set->emplace(id1); + id_to_disjoint_root_set[id0] = new_set; + id_to_disjoint_root_set[id1] = new_set; + } + }; + + auto fusion_vals = fusion->usedMathVals(); + for (auto producer_tv : ir_utils::filterByType(fusion_vals)) { + auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv); + for (auto consumer_tv : consumer_tvs) { + auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); + auto c2p_root_map = pairwise_map.mapConsumerToProducer( + consumer_tv->domain(), producer_tv->domain()); + for (auto entry : c2p_root_map) { + auto c_id = entry.first; + auto p_id = entry.second; + map_root_ids(p_id, c_id); + } + } + } + + // Map each set to an input ID (if it exists) that has the smallest ->name() + // entry value + std::unordered_map*, IterDomain*> + set_to_input_id; + + // Loop over the root domains of the inputs to the fusion. Pick an input ID + // to use as the representative ID of the collected sets. Only consider inputs + // as those are the ones that map to values like "T0.size[1]". They are the + // IDs that propagated their extents into the problem. We could also check + // the outputs as we do have C++ examples of using output dimensions for the + // problem size instead of inputs. However, we don't do anything where we can + // translate to those kinds of kernels integrated into PyTorch. + for (auto input_tv : ir_utils::filterByType(fusion->inputs())) { + for (auto id : + TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) { + auto id_set_it = id_to_disjoint_root_set.find(id); + if (id_set_it == id_to_disjoint_root_set.end()) { + continue; + } + auto* id_set = id_set_it->second; + if (set_to_input_id.find(id_set) == set_to_input_id.end()) { + set_to_input_id[id_set] = id; + } else { + auto input_id_of_set = set_to_input_id.at(id_set); + // Swap id's if new name is less than previously set + bool swap_ids = id->name() < input_id_of_set->name(); + // If new id is a const scalar but previously wasn't, use the const + // scalar + swap_ids = swap_ids || + (id->extent()->isConstScalar() && + !input_id_of_set->extent()->isConstScalar()); + // If previous scalar was const and new isn't, don't swap + swap_ids = swap_ids && + !(input_id_of_set->extent()->isConstScalar() && + !id->extent()->isConstScalar()); + + if (swap_ids) { + set_to_input_id[id_set] = id; + } + } + } + } + + // Finally make a map from ID extents to the representative ID extent.
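The extent deduplication above amounts to a union of root-domain extents known to be equal, followed by choosing one representative per set. A minimal standalone sketch of that idea keyed by extent names (the lexicographically-smallest-name rule below is a simplified stand-in for the smallest-input-ID preference used here):

#include <iostream>
#include <map>
#include <string>

// Tiny union-find over extent names.
std::map<std::string, std::string> parent;

std::string find(const std::string& x) {
  auto it = parent.find(x);
  if (it == parent.end()) {
    parent[x] = x;
    return x;
  }
  if (it->second == x) {
    return x;
  }
  std::string root = find(it->second);
  parent[x] = root;  // path compression
  return root;
}

void merge(const std::string& a, const std::string& b) {
  std::string ra = find(a);
  std::string rb = find(b);
  if (ra == rb) {
    return;
  }
  // Keep the smaller name as the representative of the merged set.
  if (ra < rb) {
    parent[rb] = ra;
  } else {
    parent[ra] = rb;
  }
}

int main() {
  // Producer/consumer root domains that are pairwise mapped to each other.
  merge("T2.size[0]", "T1.size[0]");
  merge("T2.size[1]", "T1.size[1]");
  std::cout << find("T2.size[0]") << "\n";  // T1.size[0]
  std::cout << find("T2.size[1]") << "\n";  // T1.size[1]
  return 0;
}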
+ std::unordered_map extent_to_min_input_id_extent; + for (auto entry : set_to_input_id) { + auto* set = entry.first; + auto input_id = entry.second; + for (auto id : *set) { + extent_to_min_input_id_extent[id->extent()] = input_id->extent(); + } + } + return extent_to_min_input_id_extent; +} + +std::vector allLeafOuts(Fusion* fusion) { + auto exprs = StmtSort::getExprs(fusion, true); + std::unordered_set inputs; + std::unordered_set outputs; + std::vector ordered_outputs; + for (auto expr : exprs) { + inputs.insert(expr->inputs().begin(), expr->inputs().end()); + outputs.insert(expr->outputs().begin(), expr->outputs().end()); + ordered_outputs.insert( + ordered_outputs.end(), expr->outputs().begin(), expr->outputs().end()); + } + for (auto input : inputs) { + outputs.erase(input); + } + + std::vector ordered_leaf_outs; + for (auto out : ordered_outputs) { + if (outputs.find(out) != outputs.end()) { + ordered_leaf_outs.push_back(out); + } + } + return ordered_leaf_outs; +} + +class ValReplacementMutator : private OptOutMutator { + public: + ValReplacementMutator( + Fusion* fusion, + const std::unordered_map& replacement_map) + : replacement_map_(replacement_map) { + FusionGuard fg(fusion); + + // Welford makes this a little annoying since it holds a count which is + // typically not used by anything else. If we don't grab that count, then it + // would be a tensorview that doesn't get updated extents. Therefore, first + // grab all leaves towards outputs and grab stmts from there. + auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true); + for (auto stmt : stmts) { + mutate(stmt); + } + } + + private: + using OptOutMutator::mutate; + void mutate(Val* val) final { + if (replacement_map_.find(val) == replacement_map_.end()) { + return OptOutMutator::mutate(val); + } + auto replaced_val = replacement_map_.at(val); + registerMutation(val, replaced_val); + } + + const std::unordered_map& replacement_map_; +}; + +} // namespace + +void replaceSymbolicSizes(Fusion* fusion) { + FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes"); + std::unordered_map tensor_dim_map; + + // Grab inputs and outputs + std::vector inputs_and_outputs; + for (auto val : fusion->inputs()) { + if (ir_utils::isTV(val)) { + inputs_and_outputs.push_back(val->as()); + } + } + // Symbolic size is necessary for outputs if there are no inputs. + // Otherwise infer output sizes from the inputs via expression evaluation. + if (fusion->inputs().empty()) { + for (auto val : fusion->outputs()) { + if (ir_utils::isTV(val)) { + inputs_and_outputs.push_back(val->as()); + } + } + } + + // Generate map for all tensorview root domain values to map them to symbolic + // values. i.e. T0->getRootDomain()[0] would map to a named scalar + // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir. + for (TensorView* tv : inputs_and_outputs) { + // Replace the domain with one based on Ti.size[j] + const std::vector& root_td = tv->getRootDomain(); + + size_t dim = 0; + for (auto id : root_td) { + Val* orig_size = id->extent(); + + // Output sizes could have reduction axes, which isn't what gets output. 
+ // NOLINTNEXTLINE(bugprone-branch-clone) + if (id->isReduction() || + (id->getIterType() == IterType::BroadcastWithoutStride)) { + continue; + } else if ( + id->isRFactorProduct() || + // NOLINTNEXTLINE(bugprone-branch-clone) + (id->getIterType() == IterType::BroadcastWithStride) || + orig_size->isConstScalar()) { + dim++; + continue; + } + + // Currently turn off this part for inputs of segmented fusion, + // since FusionKernelRuntime will provide these as integer inputs + if (tensor_dim_map.find(orig_size) == tensor_dim_map.end() && + !orig_size->isFusionInput() && !orig_size->isConstScalar()) { + std::stringstream ss; + ss << "T" << tv->name() << ".size[" << dim++ << "]"; + tensor_dim_map[orig_size] = IrBuilder::create( + ss.str(), orig_size->getDataType().value()); + } else { + dim++; + } + } + } + + // Use a minimal number of sizes from provided tensors. + auto extent_simplification_map = getSimplificationMap(fusion); + for (auto extent_entry : extent_simplification_map) { + auto orig_extent = extent_entry.first; + auto simplified_extent = extent_entry.second; + if (tensor_dim_map.count(orig_extent)) { + if (tensor_dim_map.count(simplified_extent)) { + tensor_dim_map[orig_extent] = tensor_dim_map[simplified_extent]; + } else { + tensor_dim_map[orig_extent] = simplified_extent; + } + } + } + + // Run mutation on the fusion with the tensor_dim_map + ValReplacementMutator(fusion, tensor_dim_map); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_replace_size.h b/torch/csrc/jit/codegen/cuda/lower_replace_size.h new file mode 100644 index 00000000000..81cee9f6ffe --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_replace_size.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// TensorViews are all based on symbolic sizes. When we first initialize them +// we don't know if they're inputs or outputs which would mean that they have +// runtime shapes. Intermediate tensors (those not going to global memory) do +// not have this information. Since we need to have the correct information in +// the kernel being fetched for shapes, we want to replace input and output +// tensors to reference the runtime structure containing sizes. 
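As a concrete picture of the renaming this pass produces, here is a minimal standalone sketch using plain strings instead of Fusion IR mutation (tensor and extent names are hypothetical): each root extent of an input tensor becomes a T<i>.size[<dim>] named scalar, and extents shared across tensors collapse onto a single name.

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical input tensors: tensor number -> symbolic extent per root axis.
  std::vector<std::pair<int, std::vector<std::string>>> inputs = {
      {0, {"i0", "i1"}},  // T0[i0, i1]
      {1, {"i0", "i2"}},  // T1[i0, i2], shares i0 with T0
  };

  // Extent -> named scalar of the form T<i>.size[<dim>].
  std::map<std::string, std::string> replacement;
  for (const auto& input : inputs) {
    const auto& extents = input.second;
    for (size_t dim = 0; dim < extents.size(); ++dim) {
      // The first tensor that mentions an extent names it; later
      // occurrences reuse that name instead of minting a new one.
      replacement.emplace(
          extents[dim],
          "T" + std::to_string(input.first) + ".size[" + std::to_string(dim) +
              "]");
    }
  }

  for (const auto& entry : replacement) {
    std::cout << entry.first << " -> " << entry.second << "\n";
  }
  // i0 -> T0.size[0], i1 -> T0.size[1], i2 -> T1.size[1]
  return 0;
}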
+void replaceSymbolicSizes(Fusion*); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 8a4f6980e01..ca451ee5f97 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include #include #include @@ -19,19 +17,17 @@ namespace fuser { namespace cuda { void ShiftPredicateInserter::insert( - kir::Expr* expr, + Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, bool within_unswitch) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - kir::TensorView* out_tv = ir_utils::getTVOutput(expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output"); + TensorView* out_tv = ir_utils::getTvOutput(expr); + TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); - TensorView* out_fuser_tv = out_tv->fuserTv(); const bool needs_shift_predicate = - gpu_lower->haloInfo().needsShiftPredicate(out_fuser_tv->definition()); + gpu_lower->haloInfo().needsShiftPredicate(out_tv->definition()); if (!needs_shift_predicate) { return; } @@ -48,12 +44,12 @@ void ShiftPredicateInserter::insert( kir::Predicate* thread_pred_expr = nullptr; if (within_unswitch) { - thread_pred_expr = ir_builder.create(thread_pred); + thread_pred_expr = IrBuilder::create(thread_pred); } kir::Predicate* shift_pred = within_unswitch ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::Shift, expr, thread_pred); // If the expr involves a thread-block barrier, set the predicate of @@ -64,7 +60,7 @@ void ShiftPredicateInserter::insert( return; } - auto shift_ite = ir_builder.create(shift_pred); + auto shift_ite = IrBuilder::create(shift_pred); auto& scope = loops.back()->body(); @@ -83,56 +79,33 @@ void ShiftPredicateInserter::insert( } // Padding by zero - kir::Predicate* padding_pred = ir_builder.create( + kir::Predicate* padding_pred = IrBuilder::create( PredicateType::Padding, expr, thread_pred); - auto bounds_ite = ir_builder.create(padding_pred); + auto bounds_ite = IrBuilder::create(padding_pred); const int pad_value = 0; - auto pad_expr = ir_builder.create( - UnaryOpType::Set, out_tv, ir_builder.create(pad_value)); + auto pad_expr = IrBuilder::create( + UnaryOpType::Set, out_tv, IrBuilder::create(pad_value)); bounds_ite->thenBody().push_back(pad_expr); // Insert the else block shift_ite->elseBody().push_back(bounds_ite); } -AxisHaloInfo::AxisHaloInfo() { - auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - setWidth(0, ir_builder.zeroVal()); - setWidth(1, ir_builder.zeroVal()); -} - -kir::Int* AxisHaloInfo::width() const { - auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - return ir_builder.addExpr(width(0), width(1))->as(); +int AxisHaloInfo::width() const { + return width(0) + width(1); } -kir::Int* AxisHaloInfo::width(int pos) const { +int AxisHaloInfo::width(int pos) const { TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2); - TORCH_INTERNAL_ASSERT(widths_[pos] != nullptr); return widths_[pos]; } -void AxisHaloInfo::setWidth(int pos, kir::Int* width) { +void AxisHaloInfo::setWidth(int pos, int width) { TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2); widths_[pos] = width; } -void AxisHaloInfo::merge(int pos, kir::Int* other) { - auto gpu_lower = GpuLower::current(); 
- kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto cur = width(pos); - kir::Int* new_width = nullptr; - if (cur->isConst() && other->isConst()) { - new_width = ir_builder.create( - std::max(cur->value().value(), other->value().value())); - } else if (cur->isZeroInt()) { - new_width = other; - } else if (other->isZeroInt()) { - new_width = cur; - } else { - new_width = ir_builder.maxExpr(width(pos), other)->as(); - } +void AxisHaloInfo::merge(int pos, int other) { + auto new_width = std::max(width(pos), other); setWidth(pos, new_width); } @@ -144,13 +117,12 @@ void AxisHaloInfo::merge(const AxisHaloInfo& other) { bool AxisHaloInfo::hasHalo() const { return std::any_of( - widths_.begin(), widths_.end(), [](auto w) { return !w->isZeroInt(); }); + widths_.begin(), widths_.end(), [](auto w) { return w != 0; }); } std::string AxisHaloInfo::toString() const { std::stringstream ss; - ss << "<" << kir::toString(width(0)) << ", " << kir::toString(width(1)) - << ">"; + ss << "<" << width(0) << ", " << width(1) << ">"; return ss.str(); } @@ -158,38 +130,21 @@ bool HaloInfo::hasRootAxisInfo(IterDomain* id) const { return root_axis_map_.find(id) != root_axis_map_.end(); } -bool HaloInfo::hasRootAxisInfo(kir::IterDomain* id) const { - return kir_root_axis_map_.find(id) != kir_root_axis_map_.end(); -} - const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const { + // TODO: Enable this check, was failing in many tests + // TORCH_INTERNAL_ASSERT( + // id->definition() == nullptr || id->isRFactorProduct(), + // "Invalid IterDomain: ", + // id); auto it = root_axis_map_.find(id); TORCH_INTERNAL_ASSERT( - it != root_axis_map_.end(), "Halo root axis info not found for ", id); - return it->second; -} - -AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - return const_cast( - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(this)->getRootAxisInfo(id)); -} - -const AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) const { - TORCH_INTERNAL_ASSERT( - id->definition() == nullptr || id->isRFactorProduct(), - "Invalid IterDomain: ", - id); - auto it = kir_root_axis_map_.find(id); - TORCH_INTERNAL_ASSERT( - it != kir_root_axis_map_.end(), + it != root_axis_map_.end(), "Halo root axis info not found for ", - kir::toString(id)); + id->toString()); return it->second; } -AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) { +AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) return const_cast( // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -200,9 +155,6 @@ void HaloInfo::setRootAxisInfo( IterDomain* id, const AxisHaloInfo& root_axis_info) { root_axis_map_[id] = root_axis_info; - kir_root_axis_map_ - [GpuLower::current()->lowerValue(id)->as()] = - root_axis_info; initializeFromRootAxisInfo(id); return; @@ -283,9 +235,6 @@ void HaloInfo::propagateRootAxisInfo( const auto& c_root = consumer->getRootDomain(); - auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - for (const auto i : c10::irange(c_root.size())) { auto c_id = c_root[i]; auto it = c2p.find(c_id); @@ -332,31 +281,19 @@ void HaloInfo::propagateRootAxisInfo( p_info.merge(c_info); } else { int pos = (offset > 0) ? 
0 : 1; - p_info.merge( - pos, - ir_builder.addExpr(c_info.width(pos), std::abs(offset)) - ->as()); + p_info.merge(pos, c_info.width(pos) + std::abs(offset)); } } else if (auto gather_op = dynamic_cast(expr)) { - const auto window_dim = - gpu_lower->lowerValue(gather_op->windowShape()[i]); - if (window_dim->isOneInt()) { + const auto window_dim = gather_op->windowShape()[i]; + if (window_dim == 1) { p_info.merge(c_info); continue; } - const auto& pad_dim = gather_op->padWidth()[i]; - const auto pad_dim0 = gpu_lower->lowerValue(pad_dim[0])->as(); - p_info.merge( - 0, ir_builder.addExpr(c_info.width(0), pad_dim0)->as()); + const auto pad_dim0 = gather_op->padWidth()[i][0]; + p_info.merge(0, c_info.width(0) + pad_dim0); // The right-side halo is propagated as: // consumer_right_halo + (window_dim - 1 - left_padding) - p_info.merge( - 1, - ir_builder - .subExpr( - ir_builder.addExpr(c_info.width(1), window_dim), - ir_builder.addExpr(pad_dim0, 1)) - ->as()); + p_info.merge(1, c_info.width(1) + window_dim - 1 - pad_dim0); } else { p_info.merge(c_info); } @@ -390,29 +327,30 @@ void HaloInfo::initializeFromRootAxisInfo(IterDomain* id) { TORCH_INTERNAL_ASSERT(hasRootAxisInfo(id)); auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); const auto& halo_info = getRootAxisInfo(id); auto halo_width = halo_info.width(); if (!halo_info.hasHalo()) { - halo_width_map_[id] = ir_builder.zeroVal(); + setHaloWidth(id, 0); return; } auto expanded_extent = - ir_builder.addExpr(gpu_lower->lowerValue(id->extent()), halo_width); - kir_extent_map_[gpu_lower->lowerValue(id)->as()] = - expanded_extent; + IrBuilder::addExpr(id->extent(), IrBuilder::create(halo_width)); + extent_map_[id] = expanded_extent; halo_width_map_[id] = halo_width; inheritance_map_[id] = {id}; } +void HaloInfo::setHaloWidth(IterDomain* id, int halo_width) { + halo_width_map_[id] = halo_width; +} + // Propagate extent information from root axes to descendants void HaloInfo::build(TensorDomain* td) { auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); auto exprs = DependencyCheck::getAllExprsBetween( {td->getMaybeRFactorDomain().begin(), td->getMaybeRFactorDomain().end()}, @@ -459,33 +397,29 @@ void HaloInfo::build(TensorDomain* td) { auto in_id = split->in(); - const auto& halo_width_it = halo_width_map_.find(in_id); - // If no halo info is found, nothing needs to be done. This ID // must be an ancestor of a domain set by setRootAxisInfo. 
- if (halo_width_it == halo_width_map_.end()) { + if (!hasHaloWidth(in_id)) { continue; } - const auto halo_width = halo_width_it->second; + const auto halo_width = getHaloWidth(in_id); - if (halo_width->isZeroInt()) { - halo_width_map_.insert({split->outer(), halo_width}); - halo_width_map_.insert({split->inner(), halo_width}); + if (halo_width == 0) { + setHaloWidth(split->outer(), 0); + setHaloWidth(split->inner(), 0); continue; } // propagate to inner domain auto out_id = split->inner(); - auto expanded_extent = ir_builder.addExpr( - gpu_lower->lowerValue(out_id->extent()), halo_width); - kir_extent_map_.insert( - {gpu_lower->lowerValue(out_id)->as(), - expanded_extent}); + auto expanded_extent = + SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width); + extent_map_.insert({out_id, expanded_extent}); - halo_width_map_.insert({split->outer(), ir_builder.zeroVal()}); - halo_width_map_.insert({split->inner(), halo_width}); + setHaloWidth(split->outer(), 0); + setHaloWidth(split->inner(), halo_width); insertToInheritanceMap(td, in_id, split->inner()); } else if (auto merge = dynamic_cast(expr)) { @@ -495,25 +429,24 @@ void HaloInfo::build(TensorDomain* td) { auto outer_extent = getExtent(merge->outer()); if (inner_extent != nullptr || outer_extent != nullptr) { if (inner_extent == nullptr) { - inner_extent = gpu_lower->lowerValue(merge->inner()->extent()); + inner_extent = merge->inner()->extent(); } else { insertToInheritanceMap(td, merge->inner(), merge->out()); } if (outer_extent == nullptr) { - outer_extent = gpu_lower->lowerValue(merge->outer()->extent()); + outer_extent = merge->outer()->extent(); } else { insertToInheritanceMap(td, merge->outer(), merge->out()); } - auto expanded_extent = ir_builder.mulExpr(outer_extent, inner_extent); - kir_extent_map_.insert( - {gpu_lower->lowerValue(merge->out())->as(), - expanded_extent}); + auto expanded_extent = + SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent); + extent_map_.insert({merge->out(), expanded_extent}); // Splitting the output of this merge is not allowed, so // remember it merged_shifted_ids.insert(merge->out()); // Note that halo_width_map_ is not updated } else { - halo_width_map_.insert({merge->out(), ir_builder.zeroVal()}); + setHaloWidth(merge->out(), 0); } } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); @@ -579,7 +512,7 @@ void HaloInfo::validate(TensorView* tv) const { bool shared_mem_needed = false; for (auto use : tv->uses()) { - if (!ir_utils::isTVOp(use)) { + if (!ir_utils::isTvOp(use)) { continue; } if (use->isA() || use->isA()) { @@ -629,21 +562,16 @@ void HaloInfo::validate(TensorView* tv) const { return; } -kir::Val* HaloInfo::getExtent(IterDomain* id) const { - auto kir_id = GpuLower::current()->lowerValue(id)->as(); - return getExtent(kir_id); -} - -kir::Val* HaloInfo::getExtent(kir::IterDomain* id) const { - auto it = kir_extent_map_.find(id); - if (it != kir_extent_map_.end()) { +Val* HaloInfo::getExtent(IterDomain* id) const { + auto it = extent_map_.find(id); + if (it != extent_map_.end()) { return it->second; } else { return nullptr; } } -kir::Int* HaloInfo::getHaloWidth(IterDomain* id) const { +int HaloInfo::getHaloWidth(IterDomain* id) const { auto it = halo_width_map_.find(id); TORCH_INTERNAL_ASSERT(it != halo_width_map_.end()); return it->second; @@ -736,63 +664,11 @@ bool extentCompare( } // namespace bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const { - auto cmp = [](kir::Int* x, kir::Int* y) { - if (x == y) { - return true; - } - auto xv = 
x->value(); - auto yv = y->value(); - return xv.has_value() && yv.has_value() && xv.value() <= yv.value(); - }; - return extentCompare(*this, id1, id2, cmp); + return extentCompare(*this, id1, id2, std::less_equal<>()); } bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const { - // Returns true only when x and y are proven to be the same. The - // analysis is not comprehensive and can prove in rather trivial - // cases only. Specifically: - // - x and y are the same pointers - // - Both have static values and they are the same - // - Both are defined by the same expression and the inputs are - // proven to be equal - std::function cmp = [&](kir::Int* x, - kir::Int* y) { - if (x == y) { - return true; - } - - auto xv = x->value(); - auto yv = y->value(); - if (xv.has_value() && yv.has_value() && xv.value() == yv.value()) { - return true; - } - - // Check if both are defined by an expression of the same type. If - // so, recursively check the input operands. - auto x_def = x->definition(); - auto y_def = y->definition(); - if (x_def && y_def && - ((x_def->isA() && y_def->isA() && - x_def->as()->operation() == - y_def->as()->operation()) || - (x_def->isA() && y_def->isA() && - x_def->as()->operation() == - y_def->as()->operation()))) { - for (const auto i : c10::irange(x_def->inputs().size())) { - auto x_input = dynamic_cast(x_def->inputs()[i]); - auto y_input = dynamic_cast(y_def->inputs()[i]); - // Both must be kir::Int - TORCH_INTERNAL_ASSERT(x_input && y_input); - if (!cmp(x_input, y_input)) { - return false; - } - } - return true; - } - - return false; - }; - return extentCompare(*this, id1, id2, cmp); + return extentCompare(*this, id1, id2, std::equal_to<>()); } std::string HaloInfo::toString() const { @@ -822,16 +698,19 @@ std::string HaloInfo::toString() const { } bool HaloInfo::needsShiftPredicate(Expr* expr) const { - auto consumer_td = ir_utils::getTVOutput(expr)->domain(); - auto shift_expr = dynamic_cast(expr); - auto gather_expr = dynamic_cast(expr); + // In lowering shift and gather turn into a unary op. We really need the shift + // expr. 
Do a round about trick to grab it: + auto tv_out = ir_utils::getTvOutput(expr); + auto consumer_td = tv_out->domain(); + auto shift_expr = dynamic_cast(tv_out->definition()); + auto gather_expr = dynamic_cast(tv_out->definition()); for (const auto i : c10::irange(consumer_td->getRootDomain().size())) { auto consumer_id = consumer_td->getRootDomain()[i]; const auto consumer_halo_info = getRootAxisInfo(consumer_id); if (consumer_halo_info.hasHalo() || (shift_expr != nullptr && shift_expr->offset(i) != 0 && !consumer_id->isBroadcast()) || - (gather_expr != nullptr && !gather_expr->windowShape()[i]->isOneInt() && + (gather_expr != nullptr && gather_expr->windowShape()[i] != 1 && !consumer_id->isBroadcast())) { return true; } @@ -839,13 +718,6 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const { return false; } -bool HaloInfo::needsShiftPredicate(kir::Expr* expr) const { - const auto out_tv = expr->outputs()[0]->as(); - auto fuser_expr = out_tv->fuserTv()->definition(); - TORCH_INTERNAL_ASSERT(fuser_expr != nullptr); - return needsShiftPredicate(fuser_expr); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index 378709ca443..c0fea8c1ead 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -16,16 +16,14 @@ namespace cuda { //! Auxiliary class to represent information about halo of an axis class AxisHaloInfo { public: - AxisHaloInfo(); - //! Width of halo. //! //! pos is either 0 or 1. The width of halo at offset zero is set //! when pos is 0. - kir::Int* width(int pos) const; + int width(int pos) const; //! Sum of the widths of both widths - kir::Int* width() const; + int width() const; const auto& widths() const { return widths_; @@ -34,10 +32,10 @@ class AxisHaloInfo { //! Set the halo width of either side. //! pos is either 0 or 1. The width of halo at offset zero is set //! when pos is 0. - void setWidth(int pos, kir::Int* width); + void setWidth(int pos, int width); //! Extend the halo width to account for another axis. - void merge(int pos, kir::Int* other); + void merge(int pos, int other); //! Extend the halo width to account for another axis. void merge(const AxisHaloInfo& other); @@ -53,7 +51,7 @@ class AxisHaloInfo { //! widths_[0] is non-zero and designates the size of the //! halo. Similarly, non-zero widths_[1] means the axis has halo at //! the other end of the axis. - std::array widths_ = {nullptr, nullptr}; + std::array widths_ = {0, 0}; }; //! Helper class for lowering tensors with halo. Only valid at the @@ -77,7 +75,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! Returns true if id has the root halo information set by //! setRootAxisInfo. bool hasRootAxisInfo(IterDomain* id) const; - bool hasRootAxisInfo(kir::IterDomain* id) const; //! Returns the registed AxisHaloInfo of a root axis. //! @@ -85,9 +82,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! non-root axes. const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const; AxisHaloInfo& getRootAxisInfo(IterDomain* id); - //! KIR version - const AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id) const; - AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id); //! Query if an axis has a halo width. //! @@ -98,12 +92,11 @@ class TORCH_CUDA_CU_API HaloInfo { //! //! It's an error if queried for an axis with no halo width //! information. 
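To make the new integer-based halo bookkeeping concrete, here is a standalone mirror of the AxisHaloInfo arithmetic after this change (a sketch for illustration only, not the class declared in lower_shift.h): widths are plain ints, merge takes the maximum, and a consumer needing one extra element on each side ends up with a total halo of 2.

#include <algorithm>
#include <array>
#include <cassert>

// Standalone mirror of the simplified halo widths; illustration only.
struct HaloWidths {
  std::array<int, 2> w{0, 0};  // w[0]: halo at the start, w[1]: at the end
  void merge(int pos, int other) { w[pos] = std::max(w[pos], other); }
  int total() const { return w[0] + w[1]; }
};

int main() {
  HaloWidths h;
  h.merge(0, 1);  // e.g. propagated from a shift with offset +1
  h.merge(1, 1);  // e.g. propagated from a shift with offset -1
  assert(h.total() == 2);  // the halo-extended extent grows by 2
  return 0;
}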
- kir::Int* getHaloWidth(IterDomain* id) const; + int getHaloWidth(IterDomain* id) const; //! Returns an extent if id is extended for halo. Nullptr is //! returned otherwise. - kir::Val* getExtent(IterDomain* id) const; - kir::Val* getExtent(kir::IterDomain* id) const; + Val* getExtent(IterDomain* id) const; //! Returns all child domains of a root domain that inherits the //! halo of the root domain. @@ -135,7 +128,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! interior and another for padding. Predicate insertion is done in //! the ShiftPredicateInserter class below. bool needsShiftPredicate(Expr* expr) const; - bool needsShiftPredicate(kir::Expr* expr) const; std::string toString() const; @@ -166,14 +158,14 @@ class TORCH_CUDA_CU_API HaloInfo { //! Validate shift usage void validate(TensorView* td) const; + void setHaloWidth(IterDomain* id, int halo_width); + private: //! Halo information of root axes std::unordered_map root_axis_map_; - //! KIR version - std::unordered_map kir_root_axis_map_; //! Halo-extended extents. No mapping for axes without halo extension - std::unordered_map kir_extent_map_; + std::unordered_map extent_map_; //! The halo width of an axis. //! @@ -209,7 +201,7 @@ class TORCH_CUDA_CU_API HaloInfo { //! inner axis is merged with another axis of extent M, we know that //! the extent of the resulting output axis is 5*M, but we don't //! create its mapping. - std::unordered_map halo_width_map_; + std::unordered_map halo_width_map_; //! Mappings from root domains to child domains that inherit halo std::unordered_map> @@ -224,9 +216,9 @@ class ShiftPredicateInserter { //! the usual predicated expression, so the insertion is also done //! here. static void insert( - kir::Expr* expr, + Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, bool within_unswitch); }; diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp index a7f8768883d..8721490feb7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -17,55 +16,49 @@ namespace cuda { namespace { -kir::Bool* getPredicatePerParallelType( +Bool* getPredicatePerParallelType( ParallelType pt, const ThreadPredicateMap::PredicateInfo& pred_info) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); // If pt is not used or is proven to be one, no need to predicate. if (pt_dim == nullptr || pt_dim->isOneInt()) { - return ir_builder.trueVal(); + return GpuLower::current()->kernel()->trueVal(); } - // When BID needs to be predicated, that means it's an output of a grid // reduction and only the last block index in that dimension has the right // value from the grid reduce. 
if (isParallelTypeBlockDim(pt) && pred_info.limited_types.get(pt)) { - return ir_builder - .eqExpr( - kir::NamedScalar::getParallelIndex(pt), - ir_builder.subExpr( - kir::NamedScalar::getParallelDim(pt), ir_builder.oneVal())) - ->as(); + return SimplifyingIrBuilder::eqExpr( + NamedScalar::getParallelIndex(pt), + SimplifyingIrBuilder::subExpr( + NamedScalar::getParallelDim(pt), + GpuLower::current()->kernel()->oneVal())) + ->as(); } // Otherwise, only thread of index 0 executes the computation - return ir_builder - .eqExpr(kir::NamedScalar::getParallelIndex(pt), ir_builder.zeroVal()) - ->as(); + return SimplifyingIrBuilder::eqExpr( + NamedScalar::getParallelIndex(pt), + GpuLower::current()->kernel()->zeroVal()) + ->as(); } } // namespace -kir::Bool* ThreadPredicateMap::getPredicateFromPredicateInfo( +Bool* ThreadPredicateMap::getPredicateFromPredicateInfo( const ThreadPredicateMap::PredicateInfo& pred_info) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - const auto pred_types = pred_info.limited_types | pred_info.redundant_types; if (pred_types.none()) { - return ir_builder.trueVal(); + return GpuLower::current()->kernel()->trueVal(); } - kir::Bool* pred = nullptr; - + Bool* pred = nullptr; for (const auto pt : pred_types) { const auto tp = getPredicatePerParallelType(pt, pred_info); - pred = ir_builder.andExpr(pred, tp)->as(); + pred = SimplifyingIrBuilder::andExpr(pred, tp)->as(); } - TORCH_INTERNAL_ASSERT(pred != nullptr); return pred; @@ -191,7 +184,9 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) { if (id->isReduction()) { id_reductions.set(id->getParallelType()); } - if (id->isBroadcast()) { + if (id->isBroadcast() && + GpuLower::current()->concretizedBroadcastDomains().isConcretized( + id)) { id_bcasts.set(id->getParallelType()); } } @@ -302,7 +297,7 @@ void ThreadPredicateMap::insert( thread_predicates_.insert({tv, pred_info}); } -kir::Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const { +Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const { TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv); auto pred_info = getPredicateInfo(tv); return getPredicateFromPredicateInfo(pred_info); @@ -326,7 +321,8 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains( const bool output_smem = tv->getMemoryType() == MemoryType::Shared; for (auto id : iter_domains) { - if (!id->isBroadcast()) { + if (!id->isBroadcast() || + !GpuLower::current()->concretizedBroadcastDomains().isConcretized(id)) { continue; } if (id->isBlockDim() || (!output_smem && id->isThreadDim())) { diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h index 256e0385aeb..0d7a2685b32 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h @@ -1,7 +1,7 @@ #pragma once -#include +#include #include #include @@ -69,7 +69,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap { ParallelTypeBitmap getPredicatedParallelTypes(const TensorView* tv) const; //! Returns a Bool predicate for a given TensorView. - kir::Bool* getPredicate(const TensorView* tv) const; + Bool* getPredicate(const TensorView* tv) const; //! Returns a ParallelTypeBitmap representing which domain needs //! blockBroadcast. @@ -81,7 +81,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap { void print() const; //! Generate a Bool value from PredicateInfo. 
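As a rough illustration of the predicate shapes these helpers produce (a sketch only; the real code builds Bool IR nodes via SimplifyingIrBuilder rather than strings): a parallel type that is merely predicated gets an index == 0 check, a block dimension holding a grid-reduction result is checked against the last block index, and getPredicateFromPredicateInfo ANDs the individual terms together.

#include <iostream>
#include <string>

// Illustration only: renders the per-parallel-type predicate as text.
std::string predicateFor(const std::string& index,
                         const std::string& dim,
                         bool limited_block_dim) {
  if (limited_block_dim) {
    // Output of a grid reduction: only the last block holds the final value.
    return index + " == " + dim + " - 1";
  }
  // Otherwise only index 0 of this parallel type runs the expression.
  return index + " == 0";
}

int main() {
  // e.g. a tensor predicated on TIDx plus a grid-reduced BIDy:
  std::cout << predicateFor("threadIdx.x", "blockDim.x", false) << " && "
            << predicateFor("blockIdx.y", "gridDim.y", true) << "\n";
  return 0;
}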
- static kir::Bool* getPredicateFromPredicateInfo( + static Bool* getPredicateFromPredicateInfo( const ThreadPredicateMap::PredicateInfo& pred_info); private: diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp new file mode 100644 index 00000000000..ab62530591a --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp @@ -0,0 +1,119 @@ +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +void ConcretizedBroadcastDomains::build(Fusion* fusion) { + // Initialize the origin map with input broadcast domains + for (const auto fusion_input_tv : + ir_utils::filterByType(fusion->inputs())) { + for (auto root_id : fusion_input_tv->getRootDomain()) { + if (root_id->isBroadcast()) { + broadcast_origin_map_.emplace( + root_id, std::unordered_set({root_id})); + } + } + } + traverse(fusion); +} + +bool ConcretizedBroadcastDomains::isConcretized(IterDomain* id) const { + auto it = concretized_domains_.find(id); + return it != concretized_domains_.end(); +} + +void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) { + // Create a new entry for each of new broadcast domains + auto out = bop->out()->as(); + for (const auto i : c10::irange(out->getRootDomain().size())) { + if (bop->getBroadcastDimFlags().at(i)) { + auto new_bcast_id = out->getRootDomain().at(i); + broadcast_origin_map_.emplace( + new_bcast_id, std::unordered_set({new_bcast_id})); + } + } +} + +void ConcretizedBroadcastDomains::handle(Expr* expr) { + IterVisitor::handle(expr); + + // Propagate broadcast origin info from producers to consumers + for (auto producer : ir_utils::filterByType(expr->inputs())) { + std::unordered_set producer_broadcasts; + // This assumes there's no merged broadcast axes between root and rfactor + // domains which is not possible at the moment. If this assumption is ever + // invalidated we would need to manaually propagate root IDs to rfactor IDs. + for (auto producer_id : producer->getMaybeRFactorDomain()) { + if (producer_id->isBroadcast()) { + producer_broadcasts.insert(producer_id); + } + } + if (producer_broadcasts.empty()) { + continue; + } + + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + auto p2c_map = + PairwiseRootDomainMap(producer, consumer) + .mapProducerToConsumer( + producer->domain(), consumer->domain(), producer_broadcasts); + for (const auto& kv : p2c_map) { + auto p_id = kv.first; + auto c_id = kv.second; + const bool is_concretized = !c_id->isBroadcast(); + auto it = broadcast_origin_map_.find(p_id); + TORCH_INTERNAL_ASSERT( + it != broadcast_origin_map_.end(), + "Broadcast origin info not found for producer broadcast domain: ", + p_id->toString(), + " of ", + producer->toString()); + const auto& producer_origins = it->second; + if (is_concretized) { + // Keep track of all the origin domains as concretized + for (auto origin : producer_origins) { + // concretized_root_domains_.insert(origin); + markAsConcretized(origin); + } + } else { + // Not concretized yet. Propagate forward the origin info. 
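A concrete case of the propagation above, written as a commented sketch (makeSymbolicTensor is the test-style helper used in the C++ tests and is an assumption here, not part of this patch):
//   auto tv0 = makeSymbolicTensor(1);           // [I0]
//   auto tv1 = makeSymbolicTensor(2);           // [J0, J1]
//   auto tv2 = broadcast(tv0, {true, false});   // [B, I0]
//   auto tv3 = add(tv2, tv1);                   // B maps to the concrete J0
// The broadcast root domain B of tv2 maps to the non-broadcast J0 of tv3, so
// B (and every origin domain recorded for it) is marked as concretized.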
+ auto& consumer_origins = broadcast_origin_map_[c_id]; + for (auto origin : producer_origins) { + consumer_origins.insert(origin); + } + consumer_origins.insert(c_id); + } + } + } + } +} + +void ConcretizedBroadcastDomains::markAsConcretized(IterDomain* root_domain) { + std::deque child_domains({root_domain}); + while (!child_domains.empty()) { + auto child = child_domains.front(); + child_domains.pop_front(); + if (!concretized_domains_.emplace(child).second) { + continue; + } + const auto& child_uses = child->uses(); + for (auto child_use : child_uses) { + for (auto out_id : + ir_utils::filterByType(child_use->outputs())) { + child_domains.push_back(out_id); + } + } + } +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h new file mode 100644 index 00000000000..9dd50e8afc1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +//! Traverse and collect all concretized broadcast domains. +//! +//! The traversal first initializes the origin map with broadcast +//! domains in input tensors. Then, a new entry is added to the origin +//! map when a broadcast op is encountered during a forward traversal +//! of the given fusion. For non-broadcast ops, mappings are just +//! propagated forward using PairwiseRootDomainMap. +//! +//! When the mapped consumer domain is not broadcast, it means the +//! producer broadcast domain is concretized, and its origin broadcast +//! domains are marked as concretized. +class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor { + public: + void build(Fusion* fusion); + + bool isConcretized(IterDomain* id) const; + + private: + using IterVisitor::handle; + + void handle(BroadcastOp* bop) final; + + void handle(Expr* expr) final; + + void markAsConcretized(IterDomain* root_domain); + + private: + //! Maps each broadcast domain to its original broadcast + //! domains. Their can be multiple original domains due to, e.g., + //! binary ops with broadcast domains in both inputs. + std::unordered_map> + broadcast_origin_map_; + //! 
Set of all concretized original domains + std::unordered_set concretized_domains_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp index 33651785d43..a8905b4d404 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp @@ -74,7 +74,7 @@ bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) { } // namespace -void TrivialReductionInfo::build(Fusion* fusion, GpuLower* gpu_lower) { +void TrivialReductionInfo::build(Fusion* fusion) { auto used_vals = fusion->usedMathVals(); for (auto tv : ir_utils::filterByType(used_vals)) { @@ -99,20 +99,6 @@ void TrivialReductionInfo::build(Fusion* fusion, GpuLower* gpu_lower) { } } } - - buildKir(fusion, gpu_lower); -} - -void TrivialReductionInfo::buildKir(Fusion* fusion, GpuLower* gpu_lower) { - for (auto id : domains_) { - auto kir_trivial_id = gpu_lower->lowerValue(id)->as(); - kir_domains_.insert(kir_trivial_id); - } - - for (auto id : domains_derived_from_root_) { - auto kir_trivial_id = gpu_lower->lowerValue(id)->as(); - kir_domains_derived_from_root_.insert(kir_trivial_id); - } } bool TrivialReductionInfo::isDerived(IterDomain* id) const { @@ -124,15 +110,6 @@ bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const { domains_derived_from_root_.end(); } -bool TrivialReductionInfo::isDerived(kir::IterDomain* id) const { - return kir_domains_.find(id) != kir_domains_.end(); -} - -bool TrivialReductionInfo::isDerivedFromRoot(kir::IterDomain* id) const { - return kir_domains_derived_from_root_.find(id) != - kir_domains_derived_from_root_.end(); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h index c16439ed4f0..9ccbc2f7828 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -13,24 +13,15 @@ namespace jit { namespace fuser { namespace cuda { -class GpuLower; - //! Detect almost all IterDomains that are derived from trivial //! reductons. class TORCH_CUDA_CU_API TrivialReductionInfo { public: - void build(Fusion* fusion, GpuLower* gpu_lower); + void build(Fusion* fusion); bool isDerived(IterDomain* id) const; bool isDerivedFromRoot(IterDomain* id) const; - bool isDerived(kir::IterDomain* id) const; - bool isDerivedFromRoot(kir::IterDomain* id) const; - - private: - //! Convert the sets to KIR sets - void buildKir(Fusion* fusion, GpuLower* gpu_lower); - private: //! IterDomains that are derived only from trivial //! reductons. Included domains are not limited to reduction axes as @@ -48,9 +39,6 @@ class TORCH_CUDA_CU_API TrivialReductionInfo { //! trivial reductions. These domains do not need to manifest as //! for-loops. 
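A concrete case, as a commented sketch (broadcast and sum are the regular fusion ops; the shapes in brackets are editorial):
//   auto tv1 = broadcast(tv0, {false, true});   // [I0, B]
//   auto tv2 = sum(tv1, {1});                   // reduces the broadcast axis B
// The reduction axis of tv2 is derived only from a broadcast (extent-1)
// domain, so it is a trivial reduction and, as the comment above notes, does
// not need to manifest as a for-loop or a block/grid reduction.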
std::unordered_set domains_derived_from_root_; - - std::unordered_set kir_domains_; - std::unordered_set kir_domains_derived_from_root_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 08f91ba59bd..c4f926131a8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include #include @@ -22,8 +20,7 @@ namespace { // Provide a new for loop matching the one provided kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - const auto new_loop = ir_builder.create(for_loop); + const auto new_loop = IrBuilder::create(for_loop); for (auto expr : for_loop->body().exprs()) { if (auto nested_for_loop = dynamic_cast(expr)) { expr = cloneLoopNest(nested_for_loop); @@ -35,20 +32,20 @@ kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) { // Returns true if expr is an expression that initializes a reduction // buffer. -bool isReductionInitExpr(const kir::Expr* expr) { +bool isReductionInitExpr(const Expr* expr) { // False if its output isn't a TensorView - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } // False if it doesn't have any reduction axis - const auto out_tv = expr->outputs()[0]->as(); + const auto out_tv = expr->outputs()[0]->as(); if (!out_tv->domain()->hasReduction()) { return false; } // False if it has have TensorView inputs as initialization should // never use TensorViews const auto tv_filter_inp_view = - ir_utils::filterByType(expr->inputs()); + ir_utils::filterByType(expr->inputs()); if (tv_filter_inp_view.begin() != tv_filter_inp_view.end()) { return false; } @@ -57,28 +54,27 @@ bool isReductionInitExpr(const kir::Expr* expr) { } // namespace -void UnrollPass::handle(kir::Expr* expr) { - if (ir_utils::isTVOp(expr)) { +void UnrollPass::handle(Expr* expr) { + if (ir_utils::isTvOp(expr)) { // If tv op, predicate it - const auto out_tv = ir_utils::getTVOutput(expr); + const auto out_tv = ir_utils::getTvOutput(expr); const bool should_predicate = !for_loops_.empty() || - out_tv->memoryType() == MemoryType::Global || - out_tv->memoryType() == MemoryType::Shared; + out_tv->getMemoryType() == MemoryType::Global || + out_tv->getMemoryType() == MemoryType::Shared; if (!should_predicate) { return; } - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); const auto thread_pred = isReductionInitExpr(expr) - ? ir_builder.trueVal() - : GpuLower::current()->threadPredMap().getPredicate(out_tv->fuserTv()); + ? GpuLower::current()->kernel()->trueVal() + : GpuLower::current()->threadPredMap().getPredicate(out_tv); // When this expr is in an unswitched block, only attach the // thread predicate to the expr as thread predicates are not // grouped to the unswitch predicate. kir::Predicate* thread_pred_expr = nullptr; if (unswitched_loop_) { - thread_pred_expr = ir_builder.create(thread_pred); + thread_pred_expr = IrBuilder::create(thread_pred); } non_trivial_pred_found_ = true; @@ -95,7 +91,7 @@ void UnrollPass::handle(kir::Expr* expr) { if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) { const auto write_pred = unswitched_loop_ ? 
thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::ReductionWrite, expr, thread_pred); expr->setWritePredicate(write_pred); } @@ -105,7 +101,7 @@ void UnrollPass::handle(kir::Expr* expr) { if (ir_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { const auto pred = unswitched_loop_ ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::Inline, expr, thread_pred); expr->setPredicate(pred); return; @@ -116,28 +112,28 @@ void UnrollPass::handle(kir::Expr* expr) { if (!unswitched_loop_ && std::any_of( for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) { - return fl->iter_domain()->parallelType() == + return fl->iter_domain()->getParallelType() == ParallelType::Vectorize; })) { - pred = ir_builder.create(PredicateType::Vectorize); + pred = IrBuilder::create(PredicateType::Vectorize); } if (pred == nullptr) { pred = unswitched_loop_ ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::Inline, expr, thread_pred); } // If we need a predicate, put expr inside an if then else - kir::IfThenElse* inline_ite = ir_builder.create(pred); + kir::IfThenElse* inline_ite = IrBuilder::create(pred); if (for_loops_.empty()) { // Special handling for top level output expressions that still // need predicates. One motivating example is a reduction op that // reduces to a scalar (issue #491) - expr_replacement_map_.insert({expr, inline_ite}); + kir::ExprMutator::registerReplace(expr, inline_ite, nullptr); } else { - for_loops_.back()->body().insert_before(expr, inline_ite); - for_loops_.back()->body().erase(expr); + kir::ExprMutator::registerReplace( + expr, inline_ite, &for_loops_.back()->body()); } inline_ite->thenBody().push_back(expr); } else if (auto for_loop = dynamic_cast(expr)) { @@ -150,8 +146,8 @@ void UnrollPass::handle(kir::Expr* expr) { void UnrollPass::handle(kir::ForLoop* fl) { // Setup for loop scoping const bool is_unroll = - fl->iter_domain()->parallelType() == ParallelType::Unroll || - fl->iter_domain()->parallelType() == ParallelType::Unswitch; + fl->iter_domain()->getParallelType() == ParallelType::Unroll || + fl->iter_domain()->getParallelType() == ParallelType::Unswitch; // If we're not looking for an unroll loop, or didn't find one, process as // normal. @@ -172,10 +168,9 @@ void UnrollPass::handle(kir::ForLoop* fl) { return; } - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto unroll_pred = ir_builder.create(fl); + auto unroll_pred = IrBuilder::create(fl); - kir::IfThenElse* unroll_ite = ir_builder.create(unroll_pred); + kir::IfThenElse* unroll_ite = IrBuilder::create(unroll_pred); // Get the loop nest for the unrolled path kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl); @@ -199,12 +194,18 @@ void UnrollPass::handle(kir::ForLoop* fl) { handle(inlined_loop); look_for_unroll_ = true; if (!non_trivial_pred_found_) { - expr_replacement_map_.insert({fl, inlined_loop}); + kir::ExprMutator::registerReplace( + fl, + inlined_loop, + for_loops_.empty() ? nullptr : &for_loops_.back()->body()); } else { if (!canOmitElseClause(fl)) { unroll_ite->elseBody().push_back(inlined_loop); } - expr_replacement_map_.insert({fl, unroll_ite}); + kir::ExprMutator::registerReplace( + fl, + unroll_ite, + for_loops_.empty() ? 
nullptr : &for_loops_.back()->body()); } } @@ -221,14 +222,14 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { // If there's any expression that requires barrier // synchronization, the else part can't be omitted for (auto expr : loop->body().exprs()) { - if (expr->isA()) { + if (expr->isA()) { const ParallelTypeBitmap domains = pred_map.getParallelBroadcastDomains( - expr->outputs()[0]->as()->fuserTv()); + expr->outputs()[0]->as()); if (domains.any()) { return false; } - } else if (expr->isA() || expr->isA()) { - auto td = ir_utils::getTVOutput(expr)->domain(); + } else if (expr->isA() || expr->isA()) { + auto td = ir_utils::getTvOutput(expr)->domain(); if (td->hasBlockReduction() || td->hasGridReduction()) { return false; } @@ -238,14 +239,14 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { // unswitch predicate is sufficient. // When the loop stop is the same as the extent of its IterDomain, // the per-thread visit count is guaranteed to be one at most (see - // CudaKernelGenerator::visit(kir::ForLoop*) as well. Also, when a + // CudaKernelGenerator::handle(kir::ForLoop*) as well. Also, when a // loop is vectorized (not misaligned), the count must be one at // most. Even if not parallelized nor vectoirzed, it is also // sufficient if the loop stop is in fact one. bool visit_once = false; auto id = loop->iter_domain(); if ((id->isThread() && (loop->stop() == id->extent())) || - id->parallelType() == ParallelType::Vectorize) { + id->getParallelType() == ParallelType::Vectorize) { visit_once = true; } if (!visit_once) { @@ -273,30 +274,18 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { } // Generate the loop nest structure and place it in lowered_exprs -UnrollPass::UnrollPass(const std::vector& exprs) { +UnrollPass::UnrollPass(const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::computeMap"); - - // Run through loop nests and further lower the expressions - for (auto* expr : exprs) { - handle(expr); - } + kir::ExprMutator::traverseAndInsert(exprs); } -std::vector UnrollPass::runPass( +std::vector UnrollPass::runPass( Fusion* fusion, - const std::vector& exprs) { + const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::runPass"); UnrollPass unroll_pass(exprs); - - std::vector mutated_exprs; - mutated_exprs.reserve(exprs.size()); - for (auto expr : exprs) { - mutated_exprs.push_back( - ir_utils::applyReplacements(unroll_pass.replacementMap(), expr)); - } - - return mutated_exprs; + return unroll_pass.exprs_; } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index bec4966dd94..14725c405b7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include #include #include #include @@ -51,33 +52,32 @@ namespace cuda { //! predicate still in the inner most loop, making sure that we cover edges and //! corners. //! 
-class TORCH_CUDA_CU_API UnrollPass { +class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { public: // Take the incoming exprs and run loop unrolling, returning the new IR - static std::vector runPass( + static std::vector runPass( Fusion* fusion, - const std::vector& exprs); + const std::vector& exprs); static bool canOmitElseClause(kir::ForLoop* fl); private: // Generate the for Expr replacement map - UnrollPass(const std::vector& exprs); + UnrollPass(const std::vector& exprs); - const std::unordered_map& replacementMap() const { + const std::unordered_map& replacementMap() const { return expr_replacement_map_; } - void handle(kir::ForLoop* fl); + using OptOutDispatch::handle; - void handle(kir::Expr* expr); + void handle(kir::ForLoop* fl) final; + + void handle(Expr* expr) final; private: // We will track which loops in the incoming IR will be replaced and by what - std::unordered_map expr_replacement_map_; - - // Keep all for loops conveniently to make unrolling easier - std::vector for_loops_; + std::unordered_map expr_replacement_map_; // keep track if we're within an unrolled loop bool look_for_unroll_ = true; diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 5d015c450d9..ba2f618efae 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include #include @@ -23,38 +21,14 @@ namespace cuda { namespace scope_utils { -std::vector getLoops(kir::Expr* scope) { - std::vector loops; - while (scope != nullptr) { - if (auto loop = dynamic_cast(scope)) { - loops.push_back(loop); - } - scope = scope->parentScope(); - } - std::reverse(loops.begin(), loops.end()); - return loops; -} - -void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr) { - if (auto ite = dynamic_cast(scope)) { - ite->thenBody().insert_before(ref, expr); - } else if (auto for_loop = dynamic_cast(scope)) { - for_loop->body().insert_before(ref, expr); - } else { - TORCH_INTERNAL_ASSERT(false, "Unexpected scope expression"); - } -} - //! Create an **empty** Forloop and copy the metadata. -kir::ForLoop* cloneForLoop(kir::IrBuilder& ir_builder, kir::ForLoop* for_loop) { - return ir_builder.create(for_loop); +kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop) { + return IrBuilder::create(for_loop); } //! Create an **empty** IfThenElse and copy the metadata. -kir::IfThenElse* cloneIfThenElse( - kir::IrBuilder& ir_builder, - kir::IfThenElse* ite) { - return ir_builder.create(ite->predicate()); +kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite) { + return IrBuilder::create(ite->predicate()); } } // namespace scope_utils @@ -103,17 +77,18 @@ std::vector iterDomainInputsOfOrderedAs( } bool isTV(const Val* val) { - return val->getValType().value() == ValType::TensorView; + return val->getValType().value() == ValType::TensorView || + val->getValType().value() == ValType::TensorIndex; } // Check if we're a TensorView op that we can generate code for. 
-bool isTVOp(const Expr* expr) { +bool isTvOp(const Expr* expr) { if (std::any_of( expr->outputs().begin(), expr->outputs().end(), [](Val* v) { return isTV(v); }) && - (expr->getExprType().value() == ExprType::BinaryOp || - expr->getExprType().value() == ExprType::UnaryOp || + (expr->getExprType().value() == ExprType::UnaryOp || + expr->getExprType().value() == ExprType::BinaryOp || expr->getExprType().value() == ExprType::TernaryOp || expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::WelfordOp || @@ -121,28 +96,26 @@ bool isTVOp(const Expr* expr) { expr->getExprType().value() == ExprType::TransposeOp || expr->getExprType().value() == ExprType::ShiftOp || expr->getExprType().value() == ExprType::GatherOp || - expr->getExprType().value() == ExprType::ViewOp)) { + expr->getExprType().value() == ExprType::ViewOp || + expr->getExprType().value() == ExprType::GridReduction || + expr->getExprType().value() == ExprType::GridBroadcast || + expr->getExprType().value() == ExprType::GridWelford)) { return true; } return false; } -bool isTVOp(const kir::Expr* expr) { - const auto& outputs = expr->outputs(); - return outputs.size() >= 1 && outputs[0]->isA(); -} - -kir::TensorView* getTv(kir::Val* val) { - if (auto tv = dynamic_cast(val)) { - return tv; - } else if (auto ti = dynamic_cast(val)) { - return ti->view(); +TensorView* getTv(Val* val) { + if (val->isA()) { + return val->as(); + } else if (val->isA()) { + return val->as()->view(); } return nullptr; } -std::vector getTvs(const std::vector& vals) { - std::vector tvs; +std::vector getTvs(const std::vector& vals) { + std::vector tvs; for (auto val : vals) { auto tv = ir_utils::getTv(val); if (tv) { @@ -152,32 +125,7 @@ std::vector getTvs(const std::vector& vals) { return tvs; } -kir::TensorView* asTv(kir::Val* val) { - auto tv = getTv(val); - TORCH_INTERNAL_ASSERT(tv != nullptr, "Neigher TensorView nor TensorIndex"); - return tv; -} - -std::vector asTvs(const std::vector vals) { - std::vector tvs; - for (auto val : vals) { - auto tv = ir_utils::asTv(val); - tvs.emplace_back(tv); - } - return tvs; -} - -// TODO: why do we assume there's a single TV output? 
-TensorView* getTVOutput(const Expr* expr) { - for (auto out : expr->outputs()) { - if (out->getValType().value() == ValType::TensorView) { - return out->as(); - } - } - return nullptr; -} - -kir::TensorView* getTVOutput(const kir::Expr* expr) { +TensorView* getTvOutput(const Expr* expr) { for (auto out : expr->outputs()) { if (auto tv = getTv(out)) { return tv; @@ -193,25 +141,20 @@ bool isScalarOp(const Expr* expr) { return true; } -Expr* asExpr(Statement* stmt) { - TORCH_INTERNAL_ASSERT(stmt->isExpr()); - return stmt->as(); -} - -TensorView* asTV(Val* val) { - TORCH_INTERNAL_ASSERT(isTV(val)); - return val->as(); -} - bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { - if (!isTVOp(expr)) { + if (!isTvOp(expr)) { + return false; + } + + if (!(expr->isA() || expr->isA() || + expr->isA() || expr->isA() || + expr->isA() || expr->isA())) { return false; } - auto tv = getTVOutput(expr); + auto tv = getTvOutput(expr); - if ((expr->isA() || expr->isA()) && - (tv->hasBlockReduction() || tv->hasGridReduction())) { + if (tv->hasBlockReduction() || tv->hasGridReduction()) { return true; } else if (expr->isA()) { const ParallelTypeBitmap pt_map = @@ -222,64 +165,22 @@ bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { return false; } -bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map) { - if (expr->isA() || expr->isA() || - expr->isA() || expr->isA() || - expr->isA() || expr->isA()) { - auto fuser_tv = getTVOutput(expr)->fuserTv(); - auto fuser_expr = fuser_tv->definition(); - TORCH_INTERNAL_ASSERT(fuser_expr != nullptr); - return hasBlockSync(fuser_expr, pred_map); - } - - return false; -} - -kir::Expr* applyReplacements( - const std::unordered_map& expr_replacement_map, - kir::Expr* expr) { - auto handle_scope = [&](kir::Scope& scope) { - for (const auto i : c10::irange(scope.size())) { - scope[i] = applyReplacements(expr_replacement_map, scope[i]); - } - }; - - const auto it = expr_replacement_map.find(expr); - if (it != expr_replacement_map.end()) { - return it->second; - } else { - if (auto for_loop = dynamic_cast(expr)) { - handle_scope(for_loop->body()); - } else if (auto ite = dynamic_cast(expr)) { - handle_scope(ite->thenBody()); - handle_scope(ite->elseBody()); - } - return expr; - } -} - -c10::optional getMaybeWarpReductionDim( - const kir::ReductionOp* node) { - auto kir_tv = ir_utils::getTVOutput(node); - if (!kir_tv) { +c10::optional getMaybeWarpReductionDim(const ReductionOp* node) { + auto tv_out = getTv(node->out()); + if (tv_out == nullptr) { return c10::nullopt; } - auto fuser_reduction = kir_tv->fuserTv()->definition()->as(); - return getMaybeWarpReductionDim(fuser_reduction); -} -c10::optional getMaybeWarpReductionDim(const ReductionOp* node) { - auto fuser_tv_out = node->out()->as(); - auto fuser_tv_in = node->in()->as(); + auto tv_in = getTv(node->in()); // only support reducing to registers for now. 
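For the static-extent case handled below, eligibility reduces to divisibility by the warp size. A minimal standalone check follows (illustration only: the real code also requires the reduction axis to be TIDx-parallel, both tensors to live in local memory, and queries at::cuda::warp_size() at runtime rather than assuming 32).

#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the static-extent divisibility check.
bool staticExtentFitsWarpReduction(int64_t extent, int64_t warp_size = 32) {
  return extent % warp_size == 0;  // assumes a warp size of 32
}

int main() {
  std::cout << staticExtentFitsWarpReduction(128) << "\n";  // 1: eligible
  std::cout << staticExtentFitsWarpReduction(100) << "\n";  // 0: not a multiple
  return 0;
}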
- if (fuser_tv_in->getMemoryType() != MemoryType::Local || - fuser_tv_out->getMemoryType() != MemoryType::Local) { + if (tv_in->getMemoryType() != MemoryType::Local || + tv_out->getMemoryType() != MemoryType::Local) { return c10::nullopt; } IterDomain* reduction_on_xdim = nullptr; - for (auto id : fuser_tv_out->domain()->domain()) { + for (auto id : tv_out->domain()->domain()) { // Currently warp reduction only allows // serial and block.x parallel reductions if (id->isReduction() && id->isParallelized()) { @@ -302,7 +203,7 @@ c10::optional getMaybeWarpReductionDim(const ReductionOp* node) { return c10::optional(reduction_on_xdim); } - if (reduction_on_xdim->extent()->isConstScalar()) { + if (reduction_on_xdim->extent()->isConst()) { auto extent_value = reduction_on_xdim->extent()->getInt().value(); if (extent_value % at::cuda::warp_size() == 0) { return c10::optional(reduction_on_xdim); @@ -329,22 +230,22 @@ bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis) { }); } -std::unordered_map getParallelDomains( - kir::Val* val) { - kir::TensorView* kir_tv = nullptr; - if (val->isA()) { - kir_tv = val->as(); +std::unordered_map getParallelDomains( + Val* val) { + TensorView* tv = nullptr; + if (val->isA()) { + tv = val->as(); } else if (val->isA()) { - kir_tv = val->as()->view(); + tv = val->as()->view(); } else { TORCH_INTERNAL_ASSERT( false, "Provided val is not TensorIndex or TensorView."); } - std::unordered_map parallel_domains; - for (auto d : kir_tv->domain()->domain()) { + std::unordered_map parallel_domains; + for (auto d : tv->domain()->domain()) { if (d->isThread()) { - parallel_domains.insert(std::make_pair(d->parallelType(), d)); + parallel_domains.insert(std::make_pair(d->getParallelType(), d)); } } return parallel_domains; @@ -354,29 +255,60 @@ std::unordered_map getParallelDomains( namespace loop_utils { -// TODO: Clean this up, Naoya added a mechanism we should be able to reuse. -std::pair getAllocPoint( +BasicAllocInfo getAllocInformation( const TensorView* tv, - const std::vector& loops, + const std::vector& for_loops, const std::unordered_map& id_map, bool use_id_map) { - const auto gpu_lower = GpuLower::current(); + BasicAllocInfo info; + auto gpu_lower = GpuLower::current(); + const auto& loop_map = gpu_lower->caLoopMap(); - // If in global memory, it can be all the way outside the loops. - if (tv->getMemoryType() == MemoryType::Global) { - return {nullptr, 0}; - } + bool outer_alloc_found = false; + + for (auto fl : for_loops) { + if (info.alloc_pos == tv->getComputeAtPosition()) { + break; + } - // Figure out where we want to place alloc/reduction initialization. We want - // outside an unroll loop, or inside our computeAt point. - kir::ForLoop* alloc_loop = nullptr; + if (tv->axis(info.alloc_pos)->isReduction()) { + const auto outputs = FusionGuard::getCurFusion()->getTerminatingOutputs(); + TORCH_INTERNAL_ASSERT( + std::find(outputs.begin(), outputs.end(), tv) != outputs.end(), + "Invalid computeAt of T", + tv->name(), + ". A reducation axis is detected outside computeAt point even though it is not an output tensor."); + break; + } + + auto fl_id = fl->iter_domain(); + + if (fl_id->getParallelType() == ParallelType::Unroll) { + break; + } + + // Shared memory must be allocated outside of unswitched + // domains. See issue #1133. + if (fl_id->getParallelType() == ParallelType::Unswitch && + tv->getMemoryType() == MemoryType::Shared) { + outer_alloc_found = true; + } + + // Assume global memory is allocated at outer most scope. 
+ if (tv->getMemoryType() == MemoryType::Global) { + outer_alloc_found = true; + } - auto loops_it = loops.begin(); - // Look at each axis individually in out's domain - for (const auto tv_i : c10::irange((int64_t)tv->getComputeAtPosition())) { - // Grab the axis ID + // Allocation of a double buffered tensor is placed outside its + // double buffer axis. + if (tv->isDoubleBuffered() && + tv->axis(info.alloc_pos) == + gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) { + outer_alloc_found = true; + } + + auto local_id = tv->axis(info.alloc_pos); - auto local_id = tv->axis(tv_i); if (use_id_map) { auto id_it = id_map.find(local_id); if (id_it != id_map.end()) { @@ -384,52 +316,33 @@ std::pair getAllocPoint( } } - if (gpu_lower->trivialReductionInfo().isDerivedFromRoot(local_id)) { - continue; + if (loop_map.areMapped(local_id, fl_id)) { + info.alloc_pos++; } - auto lowered_local_id = - gpu_lower->lowerValue(local_id)->as(); - loops_it = std::find_if( - loops_it, loops.end(), [&lowered_local_id](const auto& loop) { - return GpuLower::current()->caLoopMap().areMapped( - lowered_local_id, loop->iter_domain()) || - loop->iter_domain()->parallelType() == ParallelType::Unroll; - }); + info.init_for_loop = fl; - TORCH_INTERNAL_ASSERT( - loops_it != loops.end(), - "Could not find all required axes for indexing when trying to index into ", - tv); - if ((*loops_it)->iter_domain()->parallelType() == ParallelType::Unroll) { - return {alloc_loop, tv_i}; + if (!outer_alloc_found) { + info.alloc_for_loop = fl; } - - alloc_loop = *loops_it; - ++loops_it; } - return {alloc_loop, (int64_t)tv->getComputeAtPosition()}; -} - -std::pair getAllocPoint( - const TensorView* tv, - const std::vector& loops) { - return getAllocPoint(tv, loops, {}, false); + return info; } } // namespace loop_utils namespace { -class ReplaceExprInput : public kir::MutableIrVisitor { +class ReplaceExprInput : public OptOutDispatch { public: - static kir::Expr* replace( - kir::Expr* expr, - const std::unordered_map& replacement_map) { + using OptOutDispatch::handle; + static Expr* replace( + Expr* expr, + const std::unordered_map& replacement_map) { ReplaceExprInput replacer(expr, replacement_map); TORCH_INTERNAL_ASSERT(expr != nullptr); - expr->accept(&replacer); + replacer.handle(expr); TORCH_INTERNAL_ASSERT(replacer.replaced_expr_ != nullptr); auto ret_expr = replacer.replaced_expr_; @@ -441,10 +354,10 @@ class ReplaceExprInput : public kir::MutableIrVisitor { return ret_expr; } - static std::vector replace( - const std::vector& scope, - const std::unordered_map& replacement_map) { - std::vector ret_expr; + static std::vector replace( + const std::vector& scope, + const std::unordered_map& replacement_map) { + std::vector ret_expr; ret_expr.reserve(scope.size()); for (auto expr : scope) { @@ -455,20 +368,20 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } private: + // TODO: Replace this with mutator, example of this is done in replace + // symbolic sizes ReplaceExprInput( - kir::Expr* expr, - const std::unordered_map& replacement_map) - : gpu_lower_(GpuLower::current()), - ir_builder_(gpu_lower_->kernel()), - replacement_map_(replacement_map) { + Expr* expr, + const std::unordered_map& replacement_map) + : replacement_map_(replacement_map) { replaced_expr_ = expr; } - c10::optional> - getMaybeInputReplacementMap(kir::Expr* expr) { + c10::optional> getMaybeInputReplacementMap( + Expr* expr) { bool need_replacement = false; - std::unordered_map replaced_val; + std::unordered_map replaced_val; for (auto in : 
expr->inputs()) { auto replace_it = replacement_map_.find(in); if (replace_it != replacement_map_.end()) { @@ -479,16 +392,15 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } } if (need_replacement) { - return c10::optional>( - replaced_val); + return c10::optional>(replaced_val); } else { return c10::nullopt; } } // IR visitor interface - void visit(kir::ForLoop* for_loop) final { - auto new_for_loop = ir_builder_.create(for_loop); + void handle(kir::ForLoop* for_loop) final { + auto new_for_loop = IrBuilder::create(for_loop); auto replaced_loop_body = replace(for_loop->body().exprs(), replacement_map_); @@ -499,8 +411,8 @@ class ReplaceExprInput : public kir::MutableIrVisitor { replaced_expr_ = new_for_loop; } - void visit(kir::IfThenElse* ite) final { - auto new_ite = ir_builder_.create(ite->predicate()); + void handle(kir::IfThenElse* ite) final { + auto new_ite = IrBuilder::create(ite->predicate()); auto replaced_then_body = replace(ite->thenBody().exprs(), replacement_map_); for (auto new_expr : replaced_then_body) { @@ -516,31 +428,31 @@ class ReplaceExprInput : public kir::MutableIrVisitor { replaced_expr_ = new_ite; } - void visit(kir::UnaryOp* node) final { + void handle(UnaryOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getUnaryOpType(), node->out(), replaced_inputs.value().at(node->in())); } } - void visit(kir::BinaryOp* node) final { + void handle(BinaryOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getBinaryOpType(), node->out(), replaced_inputs.value().at(node->lhs()), replaced_inputs.value().at(node->rhs())); } } - void visit(kir::TernaryOp* node) final { + void handle(TernaryOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getTernaryOpType(), node->out(), replaced_inputs.value().at(node->in1()), replaced_inputs.value().at(node->in2()), @@ -548,29 +460,31 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } } - void visit(kir::ReductionOp* node) final { + void handle(ReductionOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getReductionOpType(), node->init(), node->out(), replaced_inputs.value().at(node->in())); } } - void visit(kir::BroadcastOp* node) final { + void handle(BroadcastOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->out(), replaced_inputs.value().at(node->in())); + replaced_expr_ = IrBuilder::create( + node->out(), + replaced_inputs.value().at(node->in()), + node->getBroadcastDimFlags()); } } - void visit(kir::WelfordOp* node) final { + void handle(WelfordOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( + replaced_expr_ = IrBuilder::create( node->outAvg(), node->outVar(), node->outN(), @@ -584,17 +498,15 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } 
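The `ReplaceExprInput` handlers above all follow one pattern: ask `getMaybeInputReplacementMap` whether any input has a registered replacement, and rebuild the expression only in that case, leaving it untouched otherwise. A minimal, self-contained sketch of that pattern (toy `Val`/`Expr` structs, not the Kernel IR classes) might look like this:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Minimal expression node; stands in for the Kernel IR Expr hierarchy.
struct Val {
  std::string name;
};

struct Expr {
  std::string op;
  std::vector<Val*> inputs;
};

// Returns a map from each input to its (possibly replaced) value, or an
// empty map when no input is affected, mirroring getMaybeInputReplacementMap.
std::unordered_map<Val*, Val*> inputReplacements(
    const Expr& expr,
    const std::unordered_map<Val*, Val*>& replacement_map) {
  std::unordered_map<Val*, Val*> replaced;
  bool need_replacement = false;
  for (Val* in : expr.inputs) {
    auto it = replacement_map.find(in);
    if (it != replacement_map.end()) {
      replaced[in] = it->second;
      need_replacement = true;
    } else {
      replaced[in] = in;  // unchanged input maps to itself
    }
  }
  return need_replacement ? replaced : std::unordered_map<Val*, Val*>{};
}

// Rebuild the expression with replaced inputs when needed; otherwise the
// original expression is reused as-is.
Expr replaceInputs(
    const Expr& expr,
    const std::unordered_map<Val*, Val*>& replacement_map) {
  auto replaced = inputReplacements(expr, replacement_map);
  if (replaced.empty()) {
    return expr;
  }
  Expr rebuilt{expr.op, {}};
  for (Val* in : expr.inputs) {
    rebuilt.inputs.push_back(replaced.at(in));
  }
  return rebuilt;
}

int main() {
  Val a{"T1[i]"}, b{"T2[i]"}, fused{"T0[i]"};
  Expr add{"add", {&a, &b}};
  // Redirect reads of T1[i] to T0[i], the way the warp-reduce fusion pass
  // redirects reads of a broadcast output to the reduction buffer.
  std::unordered_map<Val*, Val*> replacement_map{{&a, &fused}};
  Expr rewritten = replaceInputs(add, replacement_map);
  for (Val* in : rewritten.inputs) {
    std::cout << in->name << " ";
  }
  std::cout << "\n";  // prints: T0[i] T2[i]
  return 0;
}
```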
private: - GpuLower* gpu_lower_; - kir::IrBuilder ir_builder_; - kir::Expr* replaced_expr_ = nullptr; - const std::unordered_map& replacement_map_; + Expr* replaced_expr_ = nullptr; + const std::unordered_map& replacement_map_; }; } // namespace -std::vector replaceInputsInExpr( - const std::vector& exprs, - const std::unordered_map& replacement_map) { +std::vector replaceInputsInExpr( + const std::vector& exprs, + const std::unordered_map& replacement_map) { return ReplaceExprInput::replace(exprs, replacement_map); } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 1c8a0df5cd7..4ed6c25e731 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -1,7 +1,7 @@ #pragma once -#include +#include #include #include @@ -19,27 +19,15 @@ namespace cuda { class ThreadPredicateMap; -using IterDomainMap = std::unordered_map; +using IterDomainMap = std::unordered_map; namespace scope_utils { -//! Returns the list of nesting loops starting at `scope` -// Primarily used in indexing, maybe could be moved there -std::vector getLoops(kir::Expr* scope); - -//! Insert expr in scope before ref -//! -//! \warning for kir::IfThenElse we implicitly insert in the "then" branch! -//! -void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr); - //! Create an **empty** Forloop and copy the metadata. -kir::ForLoop* cloneForLoop(kir::IrBuilder& ir_builder, kir::ForLoop* for_loop); +kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop); //! Create an **empty** IfThenElse and copy the metadata. -kir::IfThenElse* cloneIfThenElse( - kir::IrBuilder& ir_builder, - kir::IfThenElse* ite); +kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite); } // namespace scope_utils @@ -74,107 +62,80 @@ std::vector iterDomainInputsOfOrderedAs( const std::vector& of, const std::vector& order); +// Returns if Val is a TensorView or TensorIndex bool isTV(const Val* const); -TORCH_CUDA_CU_API bool isTVOp(const Expr*); - -bool isTVOp(const kir::Expr* expr); - -TensorView* getTVOutput(const Expr*); -kir::TensorView* getTVOutput(const kir::Expr*); - -bool isScalarOp(const Expr*); - -// TODO(kir): remove -Expr* asExpr(Statement*); +// Returns is Expr is a TensorView or TensorIndex Expr. +TORCH_CUDA_CU_API bool isTvOp(const Expr*); -// TODO(kir): Remove in favor of ->as() -TensorView* asTV(Val*); - -//! Get kir::TensorView potentially via kir::TensorIndex. Returns nullptr if -//! cast fails. -kir::TensorView* getTv(kir::Val*); - -//! Get only kir::TensorView potentially via kir::TensorIndex. -std::vector getTvs(const std::vector& vals); - -//! Get kir::TensorView potentially via kir::TensorIndex. Error if cast fails. -kir::TensorView* asTv(kir::Val*); - -//! Get kir::TensorView potentially via kir::TensorIndex. Error if cast fails. -std::vector asTvs(const std::vector& vals); +// Returns the first output of Expr that is a TensorView +TensorView* getTvOutput(const Expr*); bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map); -bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map); - -// expr_replacement_map maps an expression to its replacement. -// -// The applyReplacement function serves two purposes. -// -// 1. If expr is found in expr_replacement_map, return the value for expr key. -// Otherwise, return the original expression. -// -// 2. 
If a replacement is not found and the expression is a ForLoop or an -// IfThenElse, it modifies the expressions in its scope by running the -// handle_scope function -// -// The handle_scope function iterates over the expressions in the scope. -// For each expression, it updates the expression the value returned by -// applyReplacement. -kir::Expr* applyReplacements( - const std::unordered_map& expr_replacement_map, - kir::Expr* expr); //! Returns the Fuser iterdomain that maps to the thread dimension grouped //! to warps. Returns nullopt if the reduction is not to be lowered to //! a warp reduction. -c10::optional getMaybeWarpReductionDim( - const kir::ReductionOp* node); - c10::optional getMaybeWarpReductionDim(const ReductionOp* node); +bool isScalarOp(const Expr*); + +//! Get TensorView potentially via kir::TensorIndex. Returns nullptr if +//! cast fails. +TensorView* getTv(Val*); + +//! Get only TensorView potentially via kir::TensorIndex. +std::vector getTvs(const std::vector& vals); + //! Return true if axis is derived from a root axis that is an input //! to a CA leaf axis. bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis); -std::unordered_map getParallelDomains( - kir::Val* val); +std::unordered_map getParallelDomains( + Val* val); } // namespace ir_utils namespace loop_utils { -// I wanted to make the tv's in these util functions constant, but that started -// a long const-ness project going into TensorView (making functions const -// there) then into lower_loops where we sort exprs. -// TODO: We should fix this when we have some time. - -// Figure out which loop the allocation needs to be in. Returns nullptr if -// outside the first loop in loops. Also find out which index in tv the -// first dimension that needs to be allocated is. Meaning we need to allocate -// that local axis and above. -// TODO: Only remaining use of this is in index compute, remove use from there, -// or refactor and use in lower_allocation -std::pair getAllocPoint( - const TensorView* tv, - const std::vector& loops, - const std::unordered_map& id_map, - bool use_id_map); +struct BasicAllocInfo { + // The for loop that the initialization of this allocation must be + // placed in, nullptr if not within a loop + kir::ForLoop* init_for_loop = nullptr; + + // Keep track of the actual allocation loop. This can be different + // from init_for_loop only with unswitched shared memory allocations, + // which are moved outer loops to avoid duplicated allocations. This means + // that the alloc position may be outside what's expected. Most applications + // outside lower_allocation is likely looking for init_for_loop which is + // more directly related to how large an allocation is and how it's used. + // (see issue #1133). + kir::ForLoop* alloc_for_loop = nullptr; + + // The allocation position relative to buffer IDs, it could be outside the + // compute at position if it's shared memory with a compute at inside an + // unswitch + size_t alloc_pos = 0; +}; -std::pair getAllocPoint( +// Fill the above allocation struct based on provided information. id_map is +// used if we're looking at a producer tensor but loops on a consumer tensor. +BasicAllocInfo getAllocInformation( const TensorView* tv, - const std::vector& loops); + const std::vector& loops, + const std::unordered_map& id_map = {}, + bool use_id_map = false); } // namespace loop_utils // Replace value pass on Kernel IR. 
-// Replace each use of any kir::Val* that apears in the given `replacement_map` +// Replace each use of any Val* that apears in the given `replacement_map` // Keeps the predicate carried by each expr // // Warning: Blindly replaces all use based on pointer // Warning: May invalidate indexing if replacing uses of allocated values -std::vector replaceInputsInExpr( - const std::vector& exprs, - const std::unordered_map& replacement_map); +std::vector replaceInputsInExpr( + const std::vector& exprs, + const std::unordered_map& replacement_map); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 0579e44dcd6..25ba76ee71b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -319,7 +318,7 @@ class VectorizeValidator : public OptInDispatch { vector_size, " however, vector sizes only upto and including 16 bytes are supported."); - auto replay_exprs = ExprSort::getExprs(fusion, {v_id}); + auto replay_exprs = StmtSort::getExprs(fusion, {v_id}, false); VectorizeValidator validator(v_id); @@ -463,6 +462,14 @@ void validateParallelizationOfTensor(TensorView* tv) { continue; } + // It doesn't matter if this axis is a non-concretized broadcast + // TODO: merging broadcast and non-broadcast + if (axis->isBroadcast() && + !GpuLower::current()->concretizedBroadcastDomains().isConcretized( + axis)) { + continue; + } + TORCH_INTERNAL_ASSERT( !pt_map.get(ptype), "Multiple use of ", @@ -489,7 +496,7 @@ void validateParallelizationOfTensor(TensorView* tv) { ". The tensor is parallelized with ", predicated_parallel_types.toString(), ", but it's invalid to use the types as the tensor is also predicated with them.", - ", thread prd: ", + ", thread pred: ", thread_pred.limited_types.toString()); } @@ -503,10 +510,10 @@ void validateParallelize(Fusion* fusion) { const auto& loop_map = GpuLower::current()->caLoopMap(); const auto& pred_map = GpuLower::current()->threadPredMap(); - auto exprs = ExprSort::getExprs(fusion); + auto exprs = StmtSort::getExprs(fusion); for (auto expr : exprs) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { continue; } // Validate parallelization of each consumer by itself @@ -630,7 +637,7 @@ namespace { // each tensor that needs to be computed. 
std::unordered_map> getLiveRangeOffsets( Fusion* fusion) { - auto exprs = ExprSort::getExprs(fusion); + auto exprs = StmtSort::getExprs(fusion); std::unordered_map> map; @@ -760,7 +767,9 @@ void validatePartialSplit(Fusion* fusion) { auto range_info = getLiveRangeOffsets(fusion); for (auto tv : ir_utils::allTvs(fusion)) { - auto exprs = ir_utils::historyOf(tv); + auto exprs = StmtSort::getExprs( + tv->fusion(), + {tv->domain()->domain().begin(), tv->domain()->domain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // When the start and stop offsets are not zero, make sure the // range defined by the split includes the required range to diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index 89de85026ee..115df13c322 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include diff --git a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp index eaddf7faea3..630d3128e78 100644 --- a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include #include @@ -18,20 +18,19 @@ namespace { //! and their corresponding allocations class EliminateDeadBroadcastAndAllocate { public: - static std::vector run(const std::vector& exprs) { + static std::vector run(const std::vector& exprs) { EliminateDeadBroadcastAndAllocate dce(exprs); return dce.result_exprs_; } private: - EliminateDeadBroadcastAndAllocate(const std::vector& exprs) - : ir_builder_(GpuLower::current()->kernel()) { + EliminateDeadBroadcastAndAllocate(const std::vector& exprs) { findLiveTvs(exprs); findDeadTvs(); eliminateDeadCode(exprs); } - void findLiveTvs(const std::vector& exprs) { + void findLiveTvs(const std::vector& exprs) { for (auto expr : exprs) { if (auto for_loop = dynamic_cast(expr)) { findLiveTvs(for_loop->body().exprs()); @@ -44,11 +43,10 @@ class EliminateDeadBroadcastAndAllocate { if (auto allocate = dynamic_cast(expr)) { if (allocate->memoryType() == MemoryType::Local) { - if (auto kir_tv = - dynamic_cast(allocate->buffer())) { + if (auto tv = dynamic_cast(allocate->buffer())) { // We know only tvs that we'd want to consider are broadcast outputs - if (kir_tv->fuserTv()->definition()->isA()) { - candidate_tv_set_.insert(kir_tv); + if (tv->definition()->isA()) { + candidate_tv_set_.insert(tv); } } } @@ -72,18 +70,18 @@ class EliminateDeadBroadcastAndAllocate { } } - void eliminateDeadCode(const std::vector& exprs) { + void eliminateDeadCode(const std::vector& exprs) { result_exprs_ = eliminateDeadCodeInScope(exprs); } - bool shouldEliminate(kir::Expr* expr) { + bool shouldEliminate(Expr* expr) { if (auto allocate = dynamic_cast(expr)) { - if (auto buffer_tv = dynamic_cast(allocate->buffer())) { + if (auto buffer_tv = dynamic_cast(allocate->buffer())) { if (dead_tvs_.count(buffer_tv)) { return true; } } - } else if (auto broadcast = dynamic_cast(expr)) { + } else if (auto broadcast = dynamic_cast(expr)) { if (auto out_ti = dynamic_cast(broadcast->out())) { if (dead_tvs_.count(out_ti->view())) { return true; @@ -95,9 +93,8 @@ class EliminateDeadBroadcastAndAllocate { //! Returns a new vector of exprs with dead exprs //! eliminated. 
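`EliminateDeadBroadcastAndAllocate` above is a small dead-code-elimination pass: collect every buffer that is still read, mark broadcast outputs that are never read as dead, and drop those broadcasts together with their allocations. The sketch below flattens the scoped traversal into a single expression list to stay self-contained; the struct and names are illustrative, not the actual Kernel IR.

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Minimal stand-in for the kernel expressions the pass walks.
struct Expr {
  enum Kind { Allocate, Broadcast, Other } kind;
  std::string output;               // buffer written (or allocated)
  std::vector<std::string> inputs;  // buffers read
};

std::vector<Expr> eliminateDeadBroadcasts(const std::vector<Expr>& exprs) {
  // 1) Every buffer that appears as an input anywhere stays live.
  std::unordered_set<std::string> live;
  for (const Expr& e : exprs) {
    for (const std::string& in : e.inputs) {
      live.insert(in);
    }
  }
  // 2) Candidates are broadcast outputs; the dead ones are never read.
  std::unordered_set<std::string> dead;
  for (const Expr& e : exprs) {
    if (e.kind == Expr::Broadcast && live.count(e.output) == 0) {
      dead.insert(e.output);
    }
  }
  // 3) Drop broadcasts into dead buffers and the matching allocations.
  std::vector<Expr> result;
  for (const Expr& e : exprs) {
    bool eliminate =
        (e.kind == Expr::Broadcast || e.kind == Expr::Allocate) &&
        dead.count(e.output) > 0;
    if (!eliminate) {
      result.push_back(e);
    }
  }
  return result;
}

int main() {
  std::vector<Expr> exprs = {
      {Expr::Allocate, "T3", {}},
      {Expr::Broadcast, "T3", {"T2"}},  // T3 is never read after fusion
      {Expr::Other, "T4", {"T2"}}};     // consumer now reads T2 directly
  auto cleaned = eliminateDeadBroadcasts(exprs);
  std::cout << "kept " << cleaned.size() << " of " << exprs.size()
            << " exprs\n";  // kept 1 of 3 exprs
  return 0;
}
```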
- std::vector eliminateDeadCodeInScope( - const std::vector& exprs) { - std::vector result_exprs; + std::vector eliminateDeadCodeInScope(const std::vector& exprs) { + std::vector result_exprs; for (auto expr : exprs) { auto result_expr = expr; @@ -128,7 +125,7 @@ class EliminateDeadBroadcastAndAllocate { // TODO: we will need a kernel_ir cloner to make this // kind of logic re-usable. - auto new_loop = scope_utils::cloneForLoop(ir_builder_, for_loop); + auto new_loop = scope_utils::cloneForLoop(for_loop); for (auto expr : new_loop_body) { new_loop->body().push_back(expr); @@ -143,7 +140,7 @@ class EliminateDeadBroadcastAndAllocate { return nullptr; } - auto new_ite = scope_utils::cloneIfThenElse(ir_builder_, ite); + auto new_ite = scope_utils::cloneIfThenElse(ite); for (auto expr : new_then_body) { new_ite->thenBody().push_back(expr); @@ -155,12 +152,11 @@ class EliminateDeadBroadcastAndAllocate { } private: - std::unordered_set live_tvs_; - std::unordered_set dead_tvs_; - std::unordered_set candidate_tv_set_; + std::unordered_set live_tvs_; + std::unordered_set dead_tvs_; + std::unordered_set candidate_tv_set_; - std::vector result_exprs_; - kir::IrBuilder ir_builder_; + std::vector result_exprs_; }; //! A pass to eliminate redundant parallel broadcasts that are consumers @@ -189,9 +185,9 @@ class EliminateDeadBroadcastAndAllocate { //! //! 3. EliminateDeadBroadcastAndAllocate removes the broadcast ops //! and corresponding allocations if they're un-used after step 2. -class FuseBroadcastWithWarpReduce { +class FuseBroadcastWithWarpReduce : private kir::IrVisitor { public: - static std::vector fuse(const std::vector& exprs) { + static std::vector fuse(const std::vector& exprs) { FuseBroadcastWithWarpReduce fuse_broadcast_map(exprs); const auto replaced_inputs = replaceInputsInExpr(exprs, fuse_broadcast_map.val_replacement_map_); @@ -199,70 +195,50 @@ class FuseBroadcastWithWarpReduce { } private: - FuseBroadcastWithWarpReduce(const std::vector& exprs) { + FuseBroadcastWithWarpReduce(const std::vector& exprs) { // open stack space for global scope - // The scope stack for kir_tv_to_allocate wouldn't be needed + // The scope stack for tv_to_allocate wouldn't be needed // if the allocations are guaranteed to be once and unique, // which can currently be assumed but this pass tries not // to rely on this assumption. 
- running_kir_tv_to_allocate_map_.emplace_back( - std::make_unique< - std::unordered_map>()); + running_tv_to_allocate_map_.emplace_back( + std::make_unique>()); running_visible_allocation_stack_.emplace_back( std::make_unique>()); - - for (auto expr : exprs) { - handle(expr); - } + kir::IrVisitor::handle(exprs); } - void handle(kir::Expr* expr) { - if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - return; - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - return; - } - - // Process expr inputs if needs replacement - for (auto inp : expr->inputs()) { - if (auto input_ti = dynamic_cast(inp)) { - auto replace = findMaybeReplacedTensorIndex(input_ti); - if (replace.has_value()) { - val_replacement_map_[input_ti] = replace.value(); + void handle(Expr* expr) final { + if (ir_utils::isTvOp(expr)) { + // Process expr inputs if needs replacement + for (auto inp : expr->inputs()) { + if (auto input_ti = dynamic_cast(inp)) { + auto replace = findMaybeReplacedTensorIndex(input_ti); + if (replace.has_value()) { + val_replacement_map_[input_ti] = replace.value(); + } } } } - - // Handle reduction definitions - if (auto reduction = dynamic_cast(expr)) { - handle(reduction); - } else if (auto broadcast = dynamic_cast(expr)) { - handle(broadcast); - } else if (auto allocate = dynamic_cast(expr)) { - handle(allocate); - } } - bool openLoopNestLevel(kir::IterDomain* id) { - if (id->isThread() || id->parallelType() == ParallelType::Unswitch) { + bool openLoopNestLevel(IterDomain* id) { + if (id->isThread() || id->getParallelType() == ParallelType::Unswitch) { return false; } - if (id->parallelType() == ParallelType::Serial || - id->parallelType() == ParallelType::Unroll) { + if (id->getParallelType() == ParallelType::Serial || + id->getParallelType() == ParallelType::Unroll) { return !id->isBroadcast(); } return true; } - void handle(kir::ForLoop* for_loop) { + void handle(kir::ForLoop* for_loop) final { // Keep track of visible reduction outputs bool open_nest_level = openLoopNestLevel(for_loop->iter_domain()); if (open_nest_level) { - running_kir_tv_to_allocate_map_.emplace_back( - std::make_unique< - std::unordered_map>()); + running_tv_to_allocate_map_.emplace_back( + std::make_unique>()); running_visible_allocation_stack_.emplace_back( std::make_unique>()); } @@ -270,12 +246,12 @@ class FuseBroadcastWithWarpReduce { handle(expr); } if (open_nest_level) { - running_kir_tv_to_allocate_map_.pop_back(); + running_tv_to_allocate_map_.pop_back(); running_visible_allocation_stack_.pop_back(); } } - void handle(kir::IfThenElse* ite) { + void handle(kir::IfThenElse* ite) final { running_visible_allocation_stack_.emplace_back( std::make_unique>()); for (auto expr : ite->thenBody().exprs()) { @@ -292,15 +268,14 @@ class FuseBroadcastWithWarpReduce { //! Place this allocate on the list of currently visible allocations, //! organized by loop nest level. - void handle(kir::Allocate* allocate) { + void handle(kir::Allocate* allocate) final { if (allocate->memoryType() != MemoryType::Local) { return; } - if (auto kir_tv = dynamic_cast(allocate->buffer())) { - auto fuser_tv = kir_tv->fuserTv(); - if (fuser_tv->definition()) { - if (fuser_tv->definition()->isA() || - fuser_tv->definition()->isA()) { + if (auto tv = dynamic_cast(allocate->buffer())) { + if (tv->definition()) { + if (tv->definition()->isA() || + tv->definition()->isA()) { running_visible_allocation_stack_.back()->push_back(allocate); } } @@ -311,18 +286,18 @@ class FuseBroadcastWithWarpReduce { //! 
returns the replaced TensorIndex if so. c10::optional findMaybeReplacedTensorIndex( kir::TensorIndex* tensor_index) { - auto kir_tv = tensor_index->view(); - auto tensor_index_it = running_tv_replacement_map_.find(kir_tv); + auto tv = tensor_index->view(); + auto tensor_index_it = running_tv_replacement_map_.find(tv); if (tensor_index_it != running_tv_replacement_map_.end()) { return tensor_index_it->second; } return c10::nullopt; } - //! Iteratve backwards on the currently visible loop scopes + //! Iterate backwards on the currently visible loop scopes //! and find the first allocation corresponding to the //! given tv. - kir::Allocate* getActiveAllocateFor(kir::TensorView* tv) { + kir::Allocate* getActiveAllocateFor(TensorView* tv) { for (auto frame_it = running_visible_allocation_stack_.rbegin(); frame_it != running_visible_allocation_stack_.rend(); frame_it++) { @@ -340,19 +315,10 @@ class FuseBroadcastWithWarpReduce { return nullptr; } - Expr* getFuserTVExpr(kir::Expr* expr) { - auto out = expr->outputs()[0]; - auto out_ti = dynamic_cast(out); - if (!out_ti) { - return nullptr; - } - return out_ti->view()->fuserTv()->definition(); - } - - bool isOpInputRegisterTV(kir::Expr* expr) { + bool isOpInputRegisterTV(Expr* expr) { for (auto inp : expr->inputs()) { if (auto inp_ti = dynamic_cast(inp)) { - if (inp_ti->view()->memoryType() != MemoryType::Local) { + if (inp_ti->view()->getMemoryType() != MemoryType::Local) { return false; } } @@ -361,10 +327,10 @@ class FuseBroadcastWithWarpReduce { return true; } - bool isOpOutputRegisterTV(kir::Expr* expr) { + bool isOpOutputRegisterTV(Expr* expr) { for (auto out : expr->outputs()) { if (auto out_ti = dynamic_cast(out)) { - if (out_ti->view()->memoryType() != MemoryType::Local) { + if (out_ti->view()->getMemoryType() != MemoryType::Local) { return false; } } @@ -374,8 +340,8 @@ class FuseBroadcastWithWarpReduce { } //! Updates map of serially visible reduction tvs, see comment on - //! running_kir_tv_to_allocate_map_. - void handle(kir::ReductionOp* reduction) { + //! running_tv_to_allocate_map_. + void handle(ReductionOp* reduction) final { if (!isOpOutputRegisterTV(reduction)) { return; } @@ -386,11 +352,11 @@ class FuseBroadcastWithWarpReduce { // keep track of which reduction buffer this expr writes into auto reduction_allocate = getActiveAllocateFor(reduction_ti_out->view()); - running_kir_tv_to_allocate_map_.back()->operator[]( - reduction_ti_out->view()) = reduction_allocate; + running_tv_to_allocate_map_.back()->operator[](reduction_ti_out->view()) = + reduction_allocate; } - void handle(kir::BroadcastOp* broadcast) { + void handle(BroadcastOp* broadcast) final { if (!isOpInputRegisterTV(broadcast) || !isOpOutputRegisterTV(broadcast)) { return; } @@ -400,9 +366,9 @@ class FuseBroadcastWithWarpReduce { //! Detects if this broadcast can be fused with the producer reduction. //! adds the output of broadcast to replacement map if all above mentioned //! conditions check. 
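The pass keeps one map per open loop-nest level (`running_visible_allocation_stack_` and `running_tv_to_allocate_map_`), and `getActiveAllocateFor` above searches those frames innermost-first. Before continuing with `tryAddOutputToReplaceMap` below, here is a toy version of that scope-stack bookkeeping, with plain integers standing in for `kir::Allocate` nodes:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// One map per open loop-nest level, searched from the innermost level out.
class ScopedAllocMap {
 public:
  ScopedAllocMap() { openScope(); }  // global scope is always present
  void openScope() { frames_.emplace_back(); }
  void closeScope() { frames_.pop_back(); }

  void record(const std::string& tv, int alloc_id) {
    frames_.back()[tv] = alloc_id;
  }

  // Walk the frames innermost-first and return the first allocation found,
  // mirroring getActiveAllocateFor; -1 means "not serially visible here".
  int lookup(const std::string& tv) const {
    for (auto it = frames_.rbegin(); it != frames_.rend(); ++it) {
      auto found = it->find(tv);
      if (found != it->end()) {
        return found->second;
      }
    }
    return -1;
  }

 private:
  std::vector<std::unordered_map<std::string, int>> frames_;
};

int main() {
  ScopedAllocMap allocs;
  allocs.record("T_red", 0);  // allocation visible from the global scope
  allocs.openScope();         // enter a serial loop
  allocs.record("T_tmp", 1);
  std::cout << allocs.lookup("T_red") << "\n";  // 0: found in an outer frame
  allocs.closeScope();        // leave the loop: T_tmp is no longer visible
  std::cout << allocs.lookup("T_tmp") << "\n";  // -1
  return 0;
}
```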
- void tryAddOutputToReplaceMap(kir::BroadcastOp* broadcast) { + void tryAddOutputToReplaceMap(BroadcastOp* broadcast) { if (auto in_ti = dynamic_cast(broadcast->in())) { - if (!in_ti->view()->fuserTv()->definition()->isA()) { + if (!in_ti->view()->definition()->isA()) { return; } auto out_ti = broadcast->out()->as(); @@ -410,15 +376,14 @@ class FuseBroadcastWithWarpReduce { // check reduction-broadcast mapping: if (!canFuseBroadcastWithWarpReduction( - out_tv->fuserTv()->definition()->as())) { + out_tv->definition()->as())) { return; } // check buffers are size-1 auto reduction_allocate_it = - running_kir_tv_to_allocate_map_.back()->find(in_ti->view()); - if (reduction_allocate_it == - running_kir_tv_to_allocate_map_.back()->end()) { + running_tv_to_allocate_map_.back()->find(in_ti->view()); + if (reduction_allocate_it == running_tv_to_allocate_map_.back()->end()) { // The producer reduction is not in the serially visible scope, // as defined in openLoopNestLevel. There still could be some // cases that we could fuse but disabled for simplicity. @@ -444,7 +409,7 @@ class FuseBroadcastWithWarpReduce { return; } - // Write the kir_tv in to the replacement map + // Write the tv in to the replacement map // so the future uses of this tv will put // the tensorIndex's in the actual replacement map. running_tv_replacement_map_[out_tv] = in_ti; @@ -515,7 +480,7 @@ class FuseBroadcastWithWarpReduce { //! could need some extension for more precise scope based analysis in the //! future especially if we have more complex IfThenElse blocks than //! predicates and unroll. - std::unordered_map + std::unordered_map running_tv_replacement_map_; //! Keeps track of the allocated buffers that the exprs will write/read @@ -531,21 +496,20 @@ class FuseBroadcastWithWarpReduce { //! visibility on the generated kernel. The model of IfThenElse assumes the //! only ITE's we have are predicates and unrolls, which might need to be //! more precise. - std::vector< - std::unique_ptr>> - running_kir_tv_to_allocate_map_; + std::vector>> + running_tv_to_allocate_map_; //! This map is the final output of this pass and a val replacement map will //! be run using //! it. All keys and values are TensorIndex's, and before this pass each //! TensorIndex is uniquely generated by lower_index pass for each access of - //! a kir_tv. - std::unordered_map val_replacement_map_; + //! a tv. + std::unordered_map val_replacement_map_; }; } // namespace -std::vector fuseWarpReduce(const std::vector exprs) { +std::vector fuseWarpReduce(const std::vector exprs) { return FuseBroadcastWithWarpReduce::fuse(exprs); } diff --git a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h index 785c0b59122..7480809c7dc 100644 --- a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h +++ b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h @@ -13,7 +13,7 @@ struct WarpPaddedParallelInfo { bool has_warp_reduction = false; }; -std::vector fuseWarpReduce(const std::vector exprs); +std::vector fuseWarpReduce(const std::vector exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp index ee1bea81535..0f5967c004d 100644 --- a/torch/csrc/jit/codegen/cuda/manager.cpp +++ b/torch/csrc/jit/codegen/cuda/manager.cpp @@ -141,6 +141,25 @@ class CudaFusionManager { int32_t next_unique_id_ = 0; }; +// Mark string attribute in alias-copy nodes to enable its implementation +// in the fallback path. 
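The `enableAliasCopyNodes` helper introduced just below recursively walks a graph's blocks and tags nodes whose kind is one of the new alias-copy ops. A self-contained approximation of that traversal, using toy node/block structs rather than the JIT IR, could look like this:

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy graph node with a nested block, loosely modeled on the shape of the
// IR that enableAliasCopyNodes walks; names are illustrative only.
struct Node {
  std::string kind;
  std::vector<Node> block;  // a single nested block is enough for the sketch
  bool marked = false;
};

void markAliasCopyNodes(std::vector<Node>& block) {
  static const std::unordered_set<std::string> alias_copy_ops = {
      "view_copy", "reshape_copy", "squeeze_copy", "unsqueeze_copy"};
  for (Node& n : block) {
    markAliasCopyNodes(n.block);  // recurse into the nested block first
    if (alias_copy_ops.count(n.kind) > 0) {
      n.marked = true;  // stands in for tagging the node for the fallback path
    }
  }
}

int main() {
  std::vector<Node> graph = {
      {"add", {}, false},
      {"If", {{"view_copy", {}, false}}, false},  // alias op inside a block
      {"reshape_copy", {}, false}};
  markAliasCopyNodes(graph);
  for (const Node& n : graph) {
    std::cout << n.kind << (n.marked ? " [marked]" : "") << "\n";
  }
  return 0;
}
```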
+void enableAliasCopyNodes(const std::shared_ptr& graph, Block* block) { + static std::unordered_set alias_copy_op( + {prim::view_copy, + prim::reshape_copy, + prim::squeeze_copy, + prim::unsqueeze_copy}); + + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + enableAliasCopyNodes(graph, b); + } + if (alias_copy_op.find(n->kind()) != alias_copy_op.end()) { + n->s_(attr::name, "CudaFusionGroup"); + } + } +} + } // namespace void compileCudaFusionGroup(Node* fusion_node) { @@ -194,6 +213,7 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) { // copying graph here since we are eliminating shape information; auto copied_graph = fusion_node->g(attr::Subgraph)->copy(); EraseShapeInformation(copied_graph); + enableAliasCopyNodes(copied_graph, copied_graph->block()); InterpreterState{Code(copied_graph, "fallback_cuda_fuser")}.run(stack); }; diff --git a/torch/csrc/jit/codegen/cuda/manager.h b/torch/csrc/jit/codegen/cuda/manager.h index 39c97478eff..4b725cd80bc 100644 --- a/torch/csrc/jit/codegen/cuda/manager.h +++ b/torch/csrc/jit/codegen/cuda/manager.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include /* diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 8d13f1e299e..c24e444eb56 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -10,143 +11,177 @@ namespace jit { namespace fuser { namespace cuda { -// MUTATE FUNCTIONS FOR VALS +void OptOutMutator::mutate(Statement* s) { + Statement::mutatorDispatch(this, s); +} + +void OptOutMutator::mutate(Expr* e) { + Expr::mutatorDispatch(this, e); +} + +void OptOutMutator::mutate(Val* v) { + Val::mutatorDispatch(this, v); +} + +void OptOutMutator::registerMutation(Val* val, Val* mutation) { + bool val_is_ns = val->vtype() == ValType::NamedScalar; + bool mutation_is_ns = mutation->vtype() == ValType::NamedScalar; + bool val_is_scalar = val->vtype() == ValType::Scalar; + bool mutation_is_scalar = mutation->vtype() == ValType::Scalar; + TORCH_INTERNAL_ASSERT( + mutation->dtype() == val->dtype() && + (mutation->vtype() == val->vtype() || + ((val_is_ns && mutation_is_scalar) || + (mutation_is_ns && val_is_scalar))), + "Mutations are not allowed to change types, tried to go from: (", + val->vtype(), + ", ", + val->dtype(), + ") to: (", + mutation->vtype(), + ", ", + mutation->dtype(), + ")"); + mutations[val] = mutation; +} + +void OptOutMutator::mutate(Bool* b) {} + +void OptOutMutator::mutate(Double* d) {} -Statement* OptOutMutator::mutate(IterDomain* id) { - Val* start = mutateAsVal(id->start())->asVal(); - Val* extent = mutateAsVal(id->extent())->asVal(); - Val* stop_offset = mutateAsVal(id->stopOffset())->asVal(); +void OptOutMutator::mutate(Int* i) {} + +void OptOutMutator::mutate(NamedScalar* ns) {} + +void OptOutMutator::mutate(IterDomain* id) { + Val* start = maybeMutated(id->start()); + Val* extent = maybeMutated(id->extent()); + Val* stop_offset = maybeMutated(id->stopOffset()); if (start->sameAs(id->start()) && extent->sameAs(id->extent()) && stop_offset->sameAs(id->stopOffset())) { - return id; + return; } - Val* mutated_val = new IterDomain( + Val* mutated_val = IrBuilder::create( + id->container(), start, extent, stop_offset, id->getParallelType(), id->getIterType(), id->isRFactorProduct()); + if (id->hasPaddingToMultipleOfWarp()) { + mutated_val->as()->padToMultipleOfWarp( + id->getMaybeSizeAfterPadding()); + } registerMutation(id, 
mutated_val); - return mutated_val; } -Statement* OptOutMutator::mutate(TensorDomain* td) { - std::vector dom; +void OptOutMutator::mutate(TensorDomain* td) { bool mutated = false; - for (const auto i : c10::irange(td->nDims())) { - IterDomain* id = mutateAsVal(td->axis(i))->as(); - dom.push_back(id); - if (!id->sameAs(td->axis(i))) - mutated = true; - } - if (mutated) { - Val* mutated_val = new TensorDomain( - td->getRootDomain(), td->getRFactorDomain(), dom, td->contiguity()); - registerMutation(td, mutated_val); - return mutated_val; + auto updateIdVec = [&](const std::vector& ids) { + std::vector updated_ids; + for (auto id : ids) { + auto updated_id = maybeMutated(id)->as(); + updated_ids.push_back(updated_id); + if (!updated_id->sameAs(id)) { + mutated = true; + } + } + return updated_ids; + }; + + std::vector root_dom = updateIdVec(td->getRootDomain()); + std::vector rfactor_dom = td->hasRFactor() + ? updateIdVec(td->getMaybeRFactorDomain()) + : std::vector(); + std::vector domain = updateIdVec(td->domain()); + + if (!mutated) { + return; } - return td; -} -Statement* OptOutMutator::mutate(TensorView* tv) { - TensorDomain* td = mutateAsVal(tv->domain())->as(); + Val* mutated_val = IrBuilder::create( + td->container(), root_dom, rfactor_dom, domain, td->contiguity()); + registerMutation(td, mutated_val); +} +void OptOutMutator::mutate(TensorView* tv) { + TensorDomain* td = maybeMutated(tv->domain())->as(); if (!tv->domain()->sameAs(td)) { - TensorView* mutated_tv = new TensorView(td, tv->getDataType().value()); - registerMutation(tv, mutated_tv); - return mutated_tv; + tv->setDomain(td); } - return tv; -} - -Statement* OptOutMutator::mutate(Bool* b) { - return b; + // Don't register tv mutations as we just want to update the TD } -Statement* OptOutMutator::mutate(Double* d) { - return d; +void OptOutMutator::mutate(kir::Predicate*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } -Statement* OptOutMutator::mutate(Int* i) { - return i; -} - -Statement* OptOutMutator::mutate(NamedScalar* ns) { - return ns; +void OptOutMutator::mutate(kir::TensorIndex*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } // MUTATE FUNCTIONS FOR EXPRESSIONS. 
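The expression mutators that follow all rely on the two helpers shown above: `registerMutation` records a `Val* -> Val*` mapping, `maybeMutated` consults it, and an expression is rebuilt only when one of its operands actually changed. A reduced, self-contained model of that pattern (illustrative names only):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

struct Val {
  std::string name;
};

class ToyMutator {
 public:
  void registerMutation(Val* val, Val* mutation) {
    mutations_[val] = mutation;
  }

  // Return the registered replacement if there is one, else the original.
  Val* maybeMutated(Val* val) const {
    auto it = mutations_.find(val);
    return it == mutations_.end() ? val : it->second;
  }

  // Rebuild a binary "expression" only if an operand was mutated; otherwise
  // report that the original can be kept untouched.
  std::string mutateAdd(Val* out, Val* lhs, Val* rhs) {
    Val* new_lhs = maybeMutated(lhs);
    Val* new_rhs = maybeMutated(rhs);
    if (new_lhs == lhs && new_rhs == rhs) {
      return "unchanged";
    }
    return out->name + " = add(" + new_lhs->name + ", " + new_rhs->name + ")";
  }

 private:
  std::unordered_map<Val*, Val*> mutations_;
};

int main() {
  Val a{"a"}, b{"b"}, c{"c"}, b2{"b_padded"};
  ToyMutator mutator;
  std::cout << mutator.mutateAdd(&c, &a, &b) << "\n";  // unchanged
  mutator.registerMutation(&b, &b2);
  std::cout << mutator.mutateAdd(&c, &a, &b) << "\n";  // c = add(a, b_padded)
  return 0;
}
```

Rebuilding only when an operand changed is what lets untouched expressions stay in place instead of being cloned wholesale.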
+void OptOutMutator::mutate(UnaryOp* uop) { + Val* out = maybeMutated(uop->out()); + Val* in = maybeMutated(uop->in()); -Statement* OptOutMutator::mutate(Split* s) { - IterDomain* ot = mutateAsVal(s->outer())->as(); - IterDomain* inr = mutateAsVal(s->inner())->as(); - IterDomain* in = mutateAsVal(s->in())->as(); - Val* fact = mutateAsVal(s->factor())->as(); - - if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) && - in->sameAs(s->in()) && areEqualScalars(fact, s->factor())) { - return s; + if (out->sameAs(uop->out()) && in->sameAs(uop->in())) { + return; } - FusionGuard::getCurFusion()->removeExpr(s); - return new Split(ot, inr, in, fact, s->innerSplit()); + auto container = uop->container(); + auto uop_type = uop->getUnaryOpType(); + container->removeExpr(uop); + IrBuilder::create(container, uop_type, out, in); } -Statement* OptOutMutator::mutate(Merge* m) { - IterDomain* ot = mutateAsVal(m->out())->as(); - IterDomain* otr = mutateAsVal(m->outer())->as(); - IterDomain* in = mutateAsVal(m->inner())->as(); +void OptOutMutator::mutate(BinaryOp* bop) { + Val* out = maybeMutated(bop->out()); + Val* lhs = maybeMutated(bop->lhs()); + Val* rhs = maybeMutated(bop->rhs()); - if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) && in->sameAs(m->inner())) - return m; - - FusionGuard::getCurFusion()->removeExpr(m); - return new Merge(ot, otr, in); -} - -Statement* OptOutMutator::mutate(UnaryOp* uop) { - Val* out = mutateAsVal(uop->out())->asVal(); - Val* in = mutateAsVal(uop->in())->asVal(); + if (out == bop->out() && lhs == bop->lhs() && rhs == bop->rhs()) { + return; + } - if (out->sameAs(uop->out()) && in->sameAs(uop->in())) - return uop; - FusionGuard::getCurFusion()->removeExpr(uop); - return new UnaryOp(uop->getUnaryOpType(), out, in); + auto container = bop->container(); + auto bop_type = bop->getBinaryOpType(); + container->removeExpr(bop); + IrBuilder::create(container, bop_type, out, lhs, rhs); } -Statement* OptOutMutator::mutate(BinaryOp* bop) { - Val* out = mutateAsVal(bop->out())->asVal(); - Val* lhs = mutateAsVal(bop->lhs())->asVal(); - Val* rhs = mutateAsVal(bop->rhs())->asVal(); - if (out == bop->out() && lhs == bop->lhs() && rhs == bop->rhs()) - return bop; - FusionGuard::getCurFusion()->removeExpr(bop); - return new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); -} +void OptOutMutator::mutate(TernaryOp* top) { + Val* out = maybeMutated(top->out()); + Val* in1 = maybeMutated(top->in1()); + Val* in2 = maybeMutated(top->in2()); + Val* in3 = maybeMutated(top->in3()); -Statement* OptOutMutator::mutate(TernaryOp* top) { - Val* out = mutateAsVal(top->out())->asVal(); - Val* in1 = mutateAsVal(top->in1())->asVal(); - Val* in2 = mutateAsVal(top->in2())->asVal(); - Val* in3 = mutateAsVal(top->in3())->asVal(); if (out == top->out() && in1 == top->in1() && in2 == top->in2() && - in3 == top->in3()) - return top; - FusionGuard::getCurFusion()->removeExpr(top); - return new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3); + in3 == top->in3()) { + return; + } + + auto container = top->container(); + auto top_type = top->getTernaryOpType(); + container->removeExpr(top); + IrBuilder::create(container, top_type, out, in1, in2, in3); } -Statement* OptOutMutator::mutate(ReductionOp* rop) { - Val* out = mutateAsVal(rop->out())->asVal(); - Val* in = mutateAsVal(rop->in())->asVal(); +void OptOutMutator::mutate(ReductionOp* rop) { + Val* out = maybeMutated(rop->out()); + Val* in = maybeMutated(rop->in()); Val* init = rop->init(); if (out->sameAs(rop->out()) && in->sameAs(rop->in()) && - 
init->sameAs(rop->init())) - return rop; + init->sameAs(rop->init())) { + return; + } - return new ReductionOp(rop->getReductionOpType(), init, out, in); + auto container = rop->container(); + auto rop_type = rop->getReductionOpType(); + container->removeExpr(rop); + IrBuilder::create(container, rop_type, init, out, in); } namespace { @@ -159,20 +194,18 @@ inline bool compareOptional(Val* a, Val* b) { } // namespace -Statement* OptOutMutator::mutate(WelfordOp* wop) { - Val* out_avg = mutateAsVal(wop->outAvg())->asVal(); - Val* out_var = mutateAsVal(wop->outVar())->asVal(); - Val* out_N = mutateAsVal(wop->outN())->asVal(); +void OptOutMutator::mutate(WelfordOp* wop) { + Val* out_avg = maybeMutated(wop->outAvg()); + Val* out_var = maybeMutated(wop->outVar()); + Val* out_N = maybeMutated(wop->outN()); - Val* in_avg = mutateAsVal(wop->inAvg())->asVal(); - Val* in_var = wop->inVar() ? mutateAsVal(wop->inVar())->asVal() : nullptr; - Val* in_N = mutateAsVal(wop->inN())->asVal(); + Val* in_avg = maybeMutated(wop->inAvg()); + Val* in_var = wop->inVar() ? maybeMutated(wop->inVar()) : nullptr; + Val* in_N = maybeMutated(wop->inN()); - Val* init_avg = - wop->initAvg() ? mutateAsVal(wop->initAvg())->asVal() : nullptr; - Val* init_var = - wop->initVar() ? mutateAsVal(wop->initVar())->asVal() : nullptr; - Val* init_N = mutateAsVal(wop->initN())->asVal(); + Val* init_avg = wop->initAvg() ? maybeMutated(wop->initAvg()) : nullptr; + Val* init_var = wop->initVar() ? maybeMutated(wop->initVar()) : nullptr; + Val* init_N = maybeMutated(wop->initN()); const bool out_compare = out_avg->sameAs(wop->outAvg()) && out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN()); @@ -182,56 +215,163 @@ Statement* OptOutMutator::mutate(WelfordOp* wop) { compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN()); if (out_compare && init_compare && in_compare) { - return wop; - } else { - return new WelfordOp( - out_avg, - out_var, - out_N, - init_avg, - init_var, - init_N, - in_avg, - in_var, - in_N); + return; } + + auto container = wop->container(); + container->removeExpr(wop); + IrBuilder::create( + container, + out_avg, + out_var, + out_N, + init_avg, + init_var, + init_N, + in_avg, + in_var, + in_N); } -Statement* OptOutMutator::mutate(BroadcastOp* bop) { - return bop; +void OptOutMutator::mutate(BroadcastOp* bop) { + Val* out = maybeMutated(bop->out()); + Val* in = maybeMutated(bop->in()); + + if (out->sameAs(bop->out()) && in->sameAs(bop->in())) { + return; + } + + auto container = bop->container(); + auto flags = bop->getBroadcastDimFlags(); + container->removeExpr(bop); + IrBuilder::create(container, out, in, flags); } -Statement* OptOutMutator::mutate(TransposeOp* top) { - return top; +void OptOutMutator::mutate(TransposeOp* top) { + TensorView* out = maybeMutated(top->out())->as(); + TensorView* in = maybeMutated(top->in())->as(); + + if (out->sameAs(top->out()) && in->sameAs(top->in())) { + return; + } + + auto container = top->container(); + auto new2old = top->new2old(); + container->removeExpr(top); + IrBuilder::create(container, out, in, new2old); } -Statement* OptOutMutator::mutate(ShiftOp* sop) { - Val* out = mutateAsVal(sop->out())->asVal(); - Val* in = mutateAsVal(sop->in())->asVal(); +void OptOutMutator::mutate(ShiftOp* sop) { + Val* out = maybeMutated(sop->out())->asVal(); + Val* in = maybeMutated(sop->in())->asVal(); + + if (out->sameAs(sop->out()) && in->sameAs(sop->in())) { + return; + } - if (out->sameAs(sop->out()) && in->sameAs(sop->in())) - return sop; auto offsets = 
sop->offsets(); - FusionGuard::getCurFusion()->removeExpr(sop); - return new ShiftOp(out, in, offsets, sop->pad()); + auto pad_width = sop->padWidth(); + auto container = sop->container(); + container->removeExpr(sop); + IrBuilder::create(container, out, in, offsets, pad_width); } -Statement* OptOutMutator::mutate(GatherOp* op) { - Val* out = mutateAsVal(op->out())->asVal(); - Val* in = mutateAsVal(op->in())->asVal(); +void OptOutMutator::mutate(GatherOp* op) { + Val* out = maybeMutated(op->out())->asVal(); + Val* in = maybeMutated(op->in())->asVal(); + + if (out->sameAs(op->out()) && in->sameAs(op->in())) { + return; + } - if (out->sameAs(op->out()) && in->sameAs(op->in())) - return op; auto window_shape = op->windowShape(); auto pad_width = op->padWidth(); - FusionGuard::getCurFusion()->removeExpr(op); - return new GatherOp(out, in, window_shape, pad_width); + auto container = op->container(); + container->removeExpr(op); + IrBuilder::create(container, out, in, window_shape, pad_width); +} + +void OptOutMutator::mutate(ViewOp* vop) { + TensorView* out = maybeMutated(vop->out())->as(); + TensorView* in = maybeMutated(vop->in())->as(); + + if (out->sameAs(vop->out()) && in->sameAs(vop->in())) { + return; + } + + auto container = vop->container(); + container->removeExpr(vop); + IrBuilder::create(container, out, in); +} + +void OptOutMutator::mutate(Split* s) { + IterDomain* ot = maybeMutated(s->outer())->as(); + IterDomain* inr = maybeMutated(s->inner())->as(); + IterDomain* in = maybeMutated(s->in())->as(); + Val* fact = maybeMutated(s->factor())->as(); + Val* start_offset = maybeMutated(s->startOffset()); + Val* stop_offset = maybeMutated(s->stopOffset()); + + if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) && + in->sameAs(s->in()) && areEqualScalars(fact, s->factor()) && + start_offset->sameAs(s->startOffset()) && + stop_offset->sameAs(s->stopOffset())) { + return; + } + + auto container = s->container(); + auto inner_split = s->innerSplit(); + container->removeExpr(s); + auto new_node = IrBuilder::create( + container, ot, inr, in, fact, inner_split, start_offset, stop_offset); +} + +void OptOutMutator::mutate(Merge* m) { + IterDomain* ot = maybeMutated(m->out())->as(); + IterDomain* otr = maybeMutated(m->outer())->as(); + IterDomain* in = maybeMutated(m->inner())->as(); + + if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) && + in->sameAs(m->inner())) { + return; + } + + auto container = m->container(); + container->removeExpr(m); + auto new_node = IrBuilder::create(container, ot, otr, in); } -Statement* OptOutMutator::mutate(ViewOp* vop) { - return vop; +void OptOutMutator::mutate(kir::Allocate*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::Sync*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::InitMagicZero*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::UpdateMagicZero*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::ForLoop*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::IfThenElse*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::GridReduction*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::GridBroadcast*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::GridWelford*) { + TORCH_INTERNAL_ASSERT(false, 
"Not implemented yet."); } +void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) { + container->removeExpr(expr); +} } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/mutator.h b/torch/csrc/jit/codegen/cuda/mutator.h index f9ec40ca9f5..433de485cf1 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.h +++ b/torch/csrc/jit/codegen/cuda/mutator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/non_divisible_split.h b/torch/csrc/jit/codegen/cuda/non_divisible_split.h index f17bf2d6246..6706c9f072d 100644 --- a/torch/csrc/jit/codegen/cuda/non_divisible_split.h +++ b/torch/csrc/jit/codegen/cuda/non_divisible_split.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.cpp b/torch/csrc/jit/codegen/cuda/ops/alias.cpp new file mode 100644 index 00000000000..14aff510911 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ops/alias.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +//! Transform TensorView according to keep, merge, and split transformations. +//! Trivial reduction and broadcast transformations are handled separately. +//! It is recommend to use the composite ops view function, which will call +//! the analyzeView function to generate the appropriate transformations. +//! +//! For example: +//! original sizes = [2, 10, 40] +//! new_size = [2, 10, 2, 20] +//! auto analysis = analyzeView(TV0, original_sizes, new_sizes) +//! auto TV1 = TV0->view(analysis.transforms); +//! +//! Transforms = [(Keep I0), (Keep I1), (Split I2 by 2)] +//! Before: TV0[I0, I1, I2] +//! After: TV0[I0, I1, 2, ceilDiv(I2, 2)] +//! +TensorView* applyViewTransforms( + TensorView* tv, + const std::vector>& transforms) { + TORCH_INTERNAL_ASSERT( + !tv->hasComputeAt(), + "Cannot modify rfactor domain after compute at has been set."); + + TORCH_INTERNAL_ASSERT(tv->nDims() > 0, "Tried to view a 0-dim TensorView"); + + TORCH_CHECK( + !tv->domain()->hasRFactor(), + "Cannot call view on the same TensorView twice."); + + TORCH_INTERNAL_ASSERT(!transforms.empty()); + + TensorView* consumer = IrBuilder::create( + tv->container(), + tv->domain()->view(transforms), + tv->getDataType().value()); + + IrBuilder::create(tv->container(), consumer, tv); + + return consumer; +} + +} // namespace + +TensorView* view( + TensorView* x, + const std::vector& original_sizes, + const std::vector& new_sizes) { + TORCH_INTERNAL_ASSERT(x->nDims() == original_sizes.size()); + + auto analyze_view = analyzeView(x, original_sizes, new_sizes); + + auto reduction = (!analyze_view.trivial_reduction_axes.empty()) + ? sum(x, analyze_view.trivial_reduction_axes) + : x; + + auto view = (!analyze_view.transforms.empty()) + ? applyViewTransforms(reduction, analyze_view.transforms) + : reduction; + + return (analyze_view.has_broadcast) + ? broadcast(view, analyze_view.broadcast_axes) + : view; +} + +TensorView* squeeze(TensorView* x, const std::vector& sizes) { + TORCH_INTERNAL_ASSERT(x->nDims() == sizes.size()); + + std::vector trivial_reduction_axes; + for (const auto idx : c10::irange(sizes.size())) { + if (sizes[idx] == 1) { + trivial_reduction_axes.push_back(idx); + } + } + return (trivial_reduction_axes.empty()) ? 
x : sum(x, trivial_reduction_axes); +} + +TensorView* squeeze(TensorView* x, const std::vector& sizes, int dim) { + TORCH_INTERNAL_ASSERT(x->nDims() == sizes.size()); + if (dim < 0) { + dim = (int)(x->nDims()) + dim; + } + TORCH_INTERNAL_ASSERT(dim >= 0 && dim < x->nDims()); + if (sizes[dim] == 1) { + return sum(x, {dim}); + } else { + return set(x); + } +} + +TensorView* unsqueeze(TensorView* x, int dim) { + if (dim < 0) { + dim = (int)(x->nDims()) + dim + 1; + } + TORCH_INTERNAL_ASSERT(dim >= 0 && dim <= x->nDims()); + + std::vector broadcast_axes(x->nDims() + 1, false); + broadcast_axes[dim] = true; + return broadcast(x, broadcast_axes); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.h b/torch/csrc/jit/codegen/cuda/ops/alias.h new file mode 100644 index 00000000000..8003e3268b3 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ops/alias.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include +#include + +// +// The operations defined in this header is intended as user facing functions. +// The user will provide the necessary input TensorViews and the function will +// create the correct intermediate nodes and return the output TensorViews. +// + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +TORCH_CUDA_CU_API TensorView* view( + TensorView* x, + const std::vector& original_sizes, + const std::vector& new_sizes); + +TORCH_CUDA_CU_API TensorView* squeeze( + TensorView* x, + const std::vector& sizes); + +TORCH_CUDA_CU_API TensorView* squeeze( + TensorView* x, + const std::vector& sizes, + int dim); + +TORCH_CUDA_CU_API TensorView* unsqueeze(TensorView* x, int dim); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ops/all_ops.h b/torch/csrc/jit/codegen/cuda/ops/all_ops.h index 1ebd2bb87f1..07d3eb944e8 100644 --- a/torch/csrc/jit/codegen/cuda/ops/all_ops.h +++ b/torch/csrc/jit/codegen/cuda/ops/all_ops.h @@ -1,4 +1,5 @@ #pragma once #include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.cpp b/torch/csrc/jit/codegen/cuda/ops/composite.cpp index 06bcf2d0494..c01b7230625 100644 --- a/torch/csrc/jit/codegen/cuda/ops/composite.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/composite.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -8,9 +9,10 @@ namespace fuser { namespace cuda { ForwardDropoutResult dropout(TensorView* x, Val* prob) { - auto p1m = sub(new Double(1.), prob); - auto zero_check = add(eq(p1m, new Double(0.)), p1m); - auto scale = div(new Double(1.), zero_check); + auto p1m = sub(IrBuilder::create(x->container(), 1.), prob); + auto zero_check = + add(eq(p1m, IrBuilder::create(x->container(), 0.)), p1m); + auto scale = div(IrBuilder::create(x->container(), 1.), zero_check); return dropout(x, p1m, scale); } @@ -91,13 +93,14 @@ Val* fast_gelu(Val* x) { auto x_cube = mul(x, mul(x, x)); - auto inner_1 = mul(new Double(kKappa), x_cube); + auto inner_1 = mul(IrBuilder::create(x->container(), kKappa), x_cube); auto inner_2 = add(x, inner_1); - auto inner_3 = mul(new Double(kBeta), inner_2); + auto inner_3 = mul(IrBuilder::create(x->container(), kBeta), inner_2); auto tanh_inner = tanh(inner_3); - auto out = mul(x, add(new Double(1.), tanh_inner)); - auto y = mul(new Double(0.5), out); + auto out = + mul(x, add(IrBuilder::create(x->container(), 1.), tanh_inner)); + auto y = mul(IrBuilder::create(x->container(), 0.5), out); return y; } @@ -111,21 +114,25 
@@ Val* fast_gelu_backward(Val* dy, Val* x) { auto x_sq = mul(x, x); auto x_cube = mul(x, x_sq); - auto inner_1 = mul(new Double(kKappa), x_cube); + auto inner_1 = mul(IrBuilder::create(x->container(), kKappa), x_cube); auto inner_2 = add(x, inner_1); - auto inner_3 = mul(new Double(kBeta), inner_2); + auto inner_3 = mul(IrBuilder::create(x->container(), kBeta), inner_2); auto tanh_inner = tanh(inner_3); - auto left = mul(new Double(0.5), x); - auto right = add(new Double(1.), tanh_inner); + auto left = mul(IrBuilder::create(x->container(), 0.5), x); + auto right = add(IrBuilder::create(x->container(), 1.), tanh_inner); - auto left_derivative = mul(new Double(0.5), right); + auto left_derivative = + mul(IrBuilder::create(x->container(), 0.5), right); auto tanh_inner_sq = mul(tanh_inner, tanh_inner); - auto tanh_derivative = sub(new Double(1), tanh_inner_sq); + auto tanh_derivative = + sub(IrBuilder::create(x->container(), 1), tanh_inner_sq); - auto constant_mul_x_sq = mul(new Double(kBeta * 3 * kKappa), x_sq); - auto inner_derivative = add(new Double(kBeta), constant_mul_x_sq); + auto constant_mul_x_sq = + mul(IrBuilder::create(x->container(), kBeta * 3 * kKappa), x_sq); + auto inner_derivative = + add(IrBuilder::create(x->container(), kBeta), constant_mul_x_sq); auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative)); auto dx = mul(dy, add(left_derivative, right_derivative)); @@ -139,79 +146,30 @@ Val* gelu_backward(Val* dy, Val* x) { constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5; const double kHalf = 0.5; - auto cdf_1 = mul(x, new Double(M_SQRT1_2)); + auto cdf_1 = mul(x, IrBuilder::create(x->container(), M_SQRT1_2)); auto cdf_2 = erf(cdf_1); - auto cdf_3 = add(cdf_2, new Double(1.)); - auto cdf_4 = mul(cdf_3, new Double(kHalf)); + auto cdf_3 = add(cdf_2, IrBuilder::create(x->container(), 1.)); + auto cdf_4 = mul(cdf_3, IrBuilder::create(x->container(), kHalf)); auto pdf_1 = mul(x, x); - auto pdf_2 = mul(pdf_1, new Double(-kHalf)); + auto pdf_2 = mul(pdf_1, IrBuilder::create(x->container(), -kHalf)); auto pdf_3 = exp(pdf_2); - auto out = addcmul(cdf_4, x, pdf_3, new Double(kAlpha)); + auto out = addcmul( + cdf_4, x, pdf_3, IrBuilder::create(x->container(), kAlpha)); auto dx = mul(out, dy); return dx; } -namespace { - -//! Transform TensorView according to keep, merge, and split transformations. -//! Trivial reduction and broadcast transformations are handled separately. -//! It is recommend to use the composite ops view function, which will call -//! the analyzeView function to generate the appropriate transformations. -//! -//! For example: -//! original sizes = [2, 10, 40] -//! new_size = [2, 10, 2, 20] -//! auto analysis = analyzeView(TV0, original_sizes, new_sizes) -//! auto TV1 = TV0->view(analysis.transforms); -//! -//! Transforms = [(Keep I0), (Keep I1), (Split I2 by 2)] -//! Before: TV0[I0, I1, I2] -//! After: TV0[I0, I1, 2, ceilDiv(I2, 2)] -//! 
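The new alias ops added earlier in this patch reduce `squeeze` to trivial reductions over size-1 axes and `unsqueeze` to a broadcast with one extra axis, after normalizing a possibly negative `dim`. That index arithmetic can be checked in isolation; the sketch below mirrors the axis and flag computation without touching any TensorView API, and its names are illustrative only.

```cpp
#include <iostream>
#include <vector>

// squeeze: every size-1 axis becomes a trivial reduction axis.
std::vector<int> squeezeAxes(const std::vector<int>& sizes) {
  std::vector<int> trivial_reduction_axes;
  for (int idx = 0; idx < static_cast<int>(sizes.size()); ++idx) {
    if (sizes[idx] == 1) {
      trivial_reduction_axes.push_back(idx);
    }
  }
  return trivial_reduction_axes;
}

// unsqueeze: insert one broadcast axis at a (possibly negative) position.
std::vector<bool> unsqueezeFlags(int ndims, int dim) {
  if (dim < 0) {
    dim = ndims + dim + 1;  // e.g. dim = -1 appends a new trailing axis
  }
  std::vector<bool> broadcast_axes(ndims + 1, false);
  broadcast_axes[dim] = true;  // only the inserted axis is a broadcast
  return broadcast_axes;
}

int main() {
  // squeeze([2, 1, 40, 1]) reduces axes 1 and 3.
  for (int axis : squeezeAxes({2, 1, 40, 1})) {
    std::cout << axis << " ";
  }
  std::cout << "\n";
  // unsqueeze of a 3-D tensor at dim = -1 -> flags [0, 0, 0, 1].
  for (bool flag : unsqueezeFlags(3, -1)) {
    std::cout << flag << " ";
  }
  std::cout << "\n";
  return 0;
}
```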
-TensorView* applyViewTransforms( - TensorView* tv, - const std::vector>& transforms) { - TORCH_INTERNAL_ASSERT( - !tv->hasComputeAt(), - "Cannot modify rfactor domain after compute at has been set."); - - TORCH_INTERNAL_ASSERT(tv->nDims() > 0, "Tried to view a 0-dim TensorView"); - - TORCH_CHECK( - !tv->domain()->hasRFactor(), - "Cannot call view on the same TensorView twice."); - - TORCH_INTERNAL_ASSERT(!transforms.empty()); - - TensorView* consumer = - new TensorView(tv->domain()->view(transforms), tv->getDataType().value()); - - new ViewOp(consumer, tv); - - return consumer; -} - -} // namespace - -TensorView* view( - TensorView* x, - const std::vector& original_sizes, - const std::vector& new_sizes) { - auto analyze_view = analyzeView(x, original_sizes, new_sizes); - - auto reduction = (!analyze_view.trivial_reduction_axes.empty()) - ? sum(x, analyze_view.trivial_reduction_axes) - : x; - - auto view = (!analyze_view.transforms.empty()) - ? applyViewTransforms(reduction, analyze_view.transforms) - : reduction; +Val* tanh_backward(Val* dy, Val* tanh_x) { + TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid."); + TORCH_INTERNAL_ASSERT(tanh_x != nullptr, "Input is invalid"); - return (analyze_view.has_broadcast) - ? broadcast(view, analyze_view.broadcast_axes) - : view; + auto one = IrBuilder::create(tanh_x->container(), 1.); + auto tanh_sq = mul(tanh_x, tanh_x); + auto sub_tanh_sq = sub(one, tanh_sq); + auto dx = mul(dy, sub_tanh_sq); + return dx; } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.h b/torch/csrc/jit/codegen/cuda/ops/composite.h index 4470f0cc6f0..63e17629f40 100644 --- a/torch/csrc/jit/codegen/cuda/ops/composite.h +++ b/torch/csrc/jit/codegen/cuda/ops/composite.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -48,11 +48,7 @@ TORCH_CUDA_CU_API LstmResult lstm( TORCH_CUDA_CU_API Val* fast_gelu(Val* x); TORCH_CUDA_CU_API Val* fast_gelu_backward(Val* dy, Val* x); TORCH_CUDA_CU_API Val* gelu_backward(Val* dy, Val* x); - -TORCH_CUDA_CU_API TensorView* view( - TensorView* x, - const std::vector& x_sizes, - const std::vector& new_sizes); +TORCH_CUDA_CU_API Val* tanh_backward(Val* dy, Val* tanh_x); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp index 19201687553..4a473f66203 100644 --- a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace torch { @@ -23,7 +24,7 @@ TensorView* softmax(TensorView* x, int dim) { auto exp_val = exp(x_max_sub); auto sum_exp = sum(exp_val, {kReductionAxis}); auto bcast_sum = broadcast(sum_exp, broadcast_mask); - auto y = div(exp_val, bcast_sum); + auto y = mul(exp_val, reciprocal(bcast_sum)); return y; } @@ -88,7 +89,7 @@ ForwardNormResult layer_norm( std::vector inner_reduction_axes(kNormShapeNumDims); std::vector inner_broadcast_mask(kNumberOfDims, false); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(x->container(), 1); for (const auto idx : c10::irange(kNormShapeNumDims)) { const size_t axis = kNumberOfDims - 1 - idx; inner_reduction_axes[idx] = axis; @@ -102,7 +103,7 @@ ForwardNormResult layer_norm( auto x_sub_mean = sub(x, mean_bcast); auto var_sum_bcast = broadcast(welford_out.var_sum, inner_broadcast_mask); - auto var = div(var_sum_bcast, num_features); + auto var = mul(var_sum_bcast, reciprocal(num_features)); auto var_eps = add(var, eps); 
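The normalization hunks above and below express divisions as `mul(..., reciprocal(...))` and compute `invstd = rsqrt(var_sum / N + eps)` from a Welford variance sum. As a plain scalar illustration of that arithmetic (not the TensorView code), assuming a small sample vector:

```cpp
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
  const double eps = 1e-5;
  const double n = static_cast<double>(x.size());

  double mean = 0.0;
  for (double v : x) {
    mean += v;
  }
  mean /= n;

  // Sum of squared deviations, i.e. the Welford variance sum.
  double var_sum = 0.0;
  for (double v : x) {
    var_sum += (v - mean) * (v - mean);
  }

  // Division written as a multiply by a reciprocal, matching the
  // div -> mul(reciprocal(...)) rewrite in the kernels.
  const double var = var_sum * (1.0 / n);           // biased variance
  const double invstd = 1.0 / std::sqrt(var + eps); // rsqrt(var + eps)

  std::cout << "mean=" << mean << " var=" << var << " invstd=" << invstd
            << "\n";
  for (double v : x) {
    std::cout << (v - mean) * invstd << " ";  // normalized output
  }
  std::cout << "\n";
  return 0;
}
```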
auto invstd = rsqrt(var_eps); @@ -156,7 +157,7 @@ BackwardNormResult layer_norm_backward( std::vector inner_reduction_axes(kNormShapeNumDims); std::vector inner_broadcast_mask(kNumberOfDims, false); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(x->container(), 1); for (const auto idx : c10::irange(kNormShapeNumDims)) { const size_t axis = kNumberOfDims - 1 - idx; inner_reduction_axes[idx] = axis; @@ -243,7 +244,7 @@ ForwardNormResult batch_norm( std::vector reduction_axes; std::vector broadcast_mask(kNumberOfDims, false); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(x->container(), 1); for (const auto axis : c10::irange(kNumberOfDims)) { if (axis != c_axis) { @@ -267,13 +268,15 @@ ForwardNormResult batch_norm( kTraining, "When running stats are provided, batch stats should only be computed during training"); - auto rev_momentum = sub(new Double(1.0), momentum); + auto rev_momentum = + sub(IrBuilder::create(x->container(), 1.0), momentum); auto current_mean_hat = mul(welford_out.avg, momentum); auto mean_hat = mul(running_mean, rev_momentum); auto new_mean_hat = add(mean_hat, current_mean_hat); - auto num_feature_decrement = sub(num_features, new Int(1)); - auto unbiased_var = div(welford_out.var_sum, num_feature_decrement); + auto num_feature_decrement = sub(num_features, x->container()->oneVal()); + auto unbiased_var = + mul(welford_out.var_sum, reciprocal(num_feature_decrement)); auto current_var_hat = mul(unbiased_var, momentum); auto var_hat = mul(running_var, rev_momentum); auto new_var_hat = add(var_hat, current_var_hat); @@ -301,14 +304,14 @@ ForwardNormResult batch_norm( fusion->aliasOutputToInput(casted_output, input_to_cast); }; - if (fusion->hasInput(running_mean)) { + if (running_mean->isFusionInput()) { fusion->addOutput(new_mean_hat); fusion->aliasOutputToInput(new_mean_hat, running_mean); } else { cast_to_input_dtype(running_mean, new_mean_hat); } - if (fusion->hasInput(running_var)) { + if (running_var->isFusionInput()) { fusion->addOutput(new_var_hat); fusion->aliasOutputToInput(new_var_hat, running_var); } else { @@ -320,7 +323,7 @@ ForwardNormResult batch_norm( auto mean_bcast = broadcast(mean, broadcast_mask); auto x_sub_mean = sub(x, mean_bcast); - auto var = div(welford_out.var_sum, num_features); + auto var = mul(welford_out.var_sum, reciprocal(num_features)); auto var_eps = add(var, eps); invstd = rsqrt(var_eps); auto invstd_bcast = broadcast(invstd, broadcast_mask); @@ -414,19 +417,6 @@ BackwardNormResult batch_norm_backward( mean = broadcast(mean, broadcast_mask); - TensorView* weight_val = nullptr; - if (weight == nullptr) { - weight_val = TensorViewBuilder() - .ndims(kNumberOfDims) - .dtype(input->getDataType().value()) - .shape(std::vector(kNumberOfDims, 1)) - .build(); - new UnaryOp( - UnaryOpType::Set, weight_val->as(), (new Double(1.0))->as()); - } else { - weight_val = broadcast(weight, broadcast_mask); - } - auto norm = reciprocal(num_features); auto grad_output_sum = sum(grad_output, reduction_axes); @@ -435,7 +425,16 @@ BackwardNormResult batch_norm_backward( auto grad_mean = broadcast(mul(grad_output_sum, norm), broadcast_mask); auto proj_scale = broadcast(mul(mul(dot_p, norm), mul(invstd, invstd)), broadcast_mask); - auto grad_scale = mul(broadcast(invstd, broadcast_mask), weight_val); + TensorView* grad_scale = nullptr; + + if (weight == nullptr) { + grad_scale = + mul(broadcast(invstd, broadcast_mask), + IrBuilder::create(input->container(), 1)); + } else { + grad_scale = mul( + 
broadcast(invstd, broadcast_mask), broadcast(weight, broadcast_mask)); + } TensorView* grad_input = nullptr; if (kTraining) { @@ -496,7 +495,7 @@ ForwardNormResult instance_norm( std::vector x_reduction_axes; std::vector x_broadcast_mask(kNumberOfDims, false); - Val* N = new Double(1); + Val* N = IrBuilder::create(x->container(), 1); for (const auto axis : c10::irange(kNumberOfDims)) { if (axis != kBatchDim && axis != kChannelsDim) { x_reduction_axes.push_back(axis); @@ -504,7 +503,7 @@ ForwardNormResult instance_norm( N = mul(N, x->domain()->domain()[axis]->extent()); } } - Val* B = new Double(1); + Val* B = IrBuilder::create(x->container(), 1); B = mul(B, x->domain()->domain()[kBatchDim]->extent()); std::vector channels_only_broadcast_mask(kNumberOfDims, false); @@ -523,7 +522,8 @@ ForwardNormResult instance_norm( // updating running mean and running var if (running_mean != nullptr && running_var != nullptr) { - auto rev_momentum = sub(new Double(1.0), momentum); + auto rev_momentum = + sub(IrBuilder::create(x->container(), 1.0), momentum); auto current_mean_hat = mul(welford_out.avg, momentum); auto mean_hat = mul(running_mean, rev_momentum); auto new_mean_hat = add(mean_hat, current_mean_hat); @@ -531,12 +531,13 @@ ForwardNormResult instance_norm( // NS: static_cast to workaround VC++ error, see // https://godbolt.org/z/6Prd77xYs auto new_mean_sum = sum(new_mean_hat, {static_cast(kBatchDim)}); - auto new_mean_channels_only = div(new_mean_sum, B); + auto new_mean_channels_only = mul(new_mean_sum, reciprocal(B)); fusion->addOutput(new_mean_channels_only); fusion->aliasOutputToInput(new_mean_channels_only, running_mean); - auto num_feature_decrement = sub(N, new Int(1)); - auto unbiased_var = div(welford_out.var_sum, num_feature_decrement); + auto num_feature_decrement = sub(N, x->container()->oneVal()); + auto unbiased_var = + mul(welford_out.var_sum, reciprocal(num_feature_decrement)); auto current_var_hat = mul(unbiased_var, momentum); auto var_hat = mul(running_var, rev_momentum); auto new_var_hat = add(var_hat, current_var_hat); @@ -544,7 +545,7 @@ ForwardNormResult instance_norm( // NS: static_cast to workaround VC++ error, see // https://godbolt.org/z/6Prd77xYs auto new_var_sum = sum(new_var_hat, {static_cast(kBatchDim)}); - auto new_var_channels_only = div(new_var_sum, B); + auto new_var_channels_only = mul(new_var_sum, reciprocal(B)); fusion->addOutput(new_var_channels_only); fusion->aliasOutputToInput(new_var_channels_only, running_var); } @@ -553,7 +554,7 @@ ForwardNormResult instance_norm( auto mean_bcast = broadcast(mean, x_broadcast_mask); auto x_sub_mean = sub(x, mean_bcast); - auto var = div(welford_out.var_sum, N); + auto var = mul(welford_out.var_sum, reciprocal(N)); auto var_eps = add(var, eps); invstd = rsqrt(var_eps); auto invstd_bcast = broadcast(invstd, x_broadcast_mask); diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.h b/torch/csrc/jit/codegen/cuda/ops/normalization.h index dae58462b92..b28cdf6b33c 100644 --- a/torch/csrc/jit/codegen/cuda/ops/normalization.h +++ b/torch/csrc/jit/codegen/cuda/ops/normalization.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp index 3dcb58335a4..d966fc21a97 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp +++ b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include #include @@ -102,7 +100,6 
@@ void ParallelDimensionMap::populateDimensionMapWithSingleCASet( TORCH_INTERNAL_ASSERT(dom_set.size() == 1); const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // pt is used by only one concrete domain auto id = *dom_set.begin(); @@ -110,16 +107,16 @@ void ParallelDimensionMap::populateDimensionMapWithSingleCASet( if (it != constant_extent_map_.end()) { if (it->second.size() == 1) { - dim_map_.insert({pt, ir_builder.create(*(it->second.begin()))}); + dim_map_.insert({pt, IrBuilder::create(*(it->second.begin()))}); exact_types_.insert(pt); } else { // Multiple constant dimensions found; Use the corresponding // symbolic parallel dim - dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)}); + dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); } } else { // Prefer to use blockDim/gridDim if not constant - dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)}); + dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); exact_types_.insert(pt); } } @@ -130,11 +127,10 @@ void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( TORCH_INTERNAL_ASSERT(dom_set.size() > 1); const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); bool all_equal = true; // Use nullptr to signal it's not initialied yet - kir::Val* known_dimension = nullptr; + Val* known_dimension = nullptr; // Use -1 to signal it's not initialied yet int64_t known_const = -1; @@ -172,7 +168,7 @@ void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( // At this point, it still remains undetermined whether this id // matches with those previously looked at. Constant check failed, // but symbolic matching may succeed. - auto this_dimension = gpu_lower->lowerValue(concrete_id->extent()); + auto this_dimension = concrete_id->extent(); if (known_dimension == nullptr) { // No previous dimension found yet known_dimension = this_dimension; @@ -191,15 +187,14 @@ void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( } // Use the const value, if found, as its dimension if (all_equal && known_const != -1) { - dim_map_.insert({pt, ir_builder.create(known_const)}); + dim_map_.insert({pt, IrBuilder::create(known_const)}); } else { - dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)}); + dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); } } void ParallelDimensionMap::adjustMappingsForWarpPadding() { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // If TIDx is padded to a multiple of the warp size, mark it as // non-exact. @@ -215,7 +210,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { // If the dimension of TIDx is actually a multple of the warp size // before padding, it can be left as exact if (isExact(tidx_pt)) { - auto tidx_dim = dynamic_cast(get(tidx_pt)); + auto tidx_dim = dynamic_cast(get(tidx_pt)); if (tidx_dim && tidx_dim->isConst()) { auto tidx_dim_val = tidx_dim->value().value(); if (tidx_dim_val % warp_size == 0) { @@ -229,17 +224,17 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { // single warp, use the constant warp size as the dimension of // TIDx. Otherwise, jsut use blockDim.x. 
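Editor note on the populate routines above: the decision they implement is small. When all concrete domains mapped to a parallel type agree on a single constant extent, that constant becomes the recorded dimension; otherwise the symbolic blockDim/gridDim scalar is recorded. A minimal sketch of that policy, with strings standing in for the Val* the real dim_map_ stores (illustrative only, not the real class):

#include <cstdint>
#include <set>
#include <string>

std::string resolveParallelDim(
    const std::set<int64_t>& constant_extents,  // constants seen for this type
    const std::string& symbolic_dim) {          // e.g. "blockDim.x"
  if (constant_extents.size() == 1) {
    // A single agreed-upon constant: use it directly.
    return std::to_string(*constant_extents.begin());
  }
  // No constant, or conflicting constants: fall back to the launch-bound
  // symbolic parallel dimension.
  return symbolic_dim;
}

For example, resolveParallelDim({128}, "blockDim.x") yields "128", while resolveParallelDim({128, 256}, "blockDim.x") falls back to "blockDim.x".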
if (warp_info.is_tidx_single_warp) { - dim_map_.at(ParallelType::TIDx) = ir_builder.create(warp_size); + dim_map_.at(ParallelType::TIDx) = IrBuilder::create(warp_size); } else { dim_map_.at(ParallelType::TIDx) = - kir::NamedScalar::getParallelDim(ParallelType::TIDx); + NamedScalar::getParallelDim(ParallelType::TIDx); } // TIDx is no longer exact exact_types_.erase(ParallelType::TIDx); } -kir::Val* ParallelDimensionMap::get(ParallelType pt) const { +Val* ParallelDimensionMap::get(ParallelType pt) const { TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt); auto it = dim_map_.find(pt); if (it == dim_map_.end()) { @@ -261,7 +256,7 @@ IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) { // Symbolically compares equality of two KIR vals. Comparison is done // conservatively, so returning false does not guarantee non-equality. -bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { +bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) { TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr); if (dim1 == dim2) { @@ -269,8 +264,8 @@ bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { } // When Both are Int, they are same if both have the same constant - auto dim1_int = dynamic_cast(dim1); - auto dim2_int = dynamic_cast(dim2); + auto dim1_int = dynamic_cast(dim1); + auto dim2_int = dynamic_cast(dim2); if (dim1_int && dim2_int) { if (dim1_int->isConst() && dim2_int->isConst()) { return dim1_int->value() == dim2_int->value(); @@ -279,8 +274,8 @@ bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { // When both are NamedScalar, they are same if Both have the same // name - auto dim1_ns = dynamic_cast(dim1); - auto dim2_ns = dynamic_cast(dim2); + auto dim1_ns = dynamic_cast(dim1); + auto dim2_ns = dynamic_cast(dim2); if (dim1_ns && dim2_ns) { return dim1_ns->name() == dim2_ns->name(); } @@ -297,12 +292,12 @@ bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { // If both are BinaryOp or UnaryOp, check their inputs. Since these // Vals are IterDomain extents, UnaryOp should not occur, but // checking shouldn't be harmful. - if ((dim1_def->isA() && dim2_def->isA() && - (dim1_def->as()->operation() == - dim2_def->as()->operation())) || - (dim1_def->isA() && dim2_def->isA() && - (dim1_def->as()->operation() == - dim2_def->as()->operation()))) { + if ((dim1_def->isA() && dim2_def->isA() && + (dim1_def->as()->getBinaryOpType() == + dim2_def->as()->getBinaryOpType())) || + (dim1_def->isA() && dim2_def->isA() && + (dim1_def->as()->getUnaryOpType() == + dim2_def->as()->getUnaryOpType()))) { for (const auto i : c10::irange(dim1_def->inputs().size())) { (void)i; // Suppress unused variable warning if (!equalDim(dim1_def->inputs()[0], dim2_def->inputs()[0])) { @@ -321,7 +316,7 @@ std::string ParallelDimensionMap::toString() const { ss << pt << ": "; auto dim = get(pt); if (dim != nullptr) { - ss << kir::toString(dim); + ss << dim->toString(); if (isExact(pt)) { ss << ", exact"; } else { diff --git a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h index d05c17adea2..03bd513396f 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h +++ b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h @@ -21,7 +21,7 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { //! Returns the dimension of a ParallelType. nullptr is returned if //! a ParallelType is unused. 
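Editor note on equalDim above: it only ever proves equality, returning false whenever it cannot match constants, NamedScalar names, or the defining ops structurally, so false never means "proved different". A simplified toy version over a stand-in expression type, kept conservative in the same spirit (this is not the real Val hierarchy):

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

struct ToyExtent {
  bool is_const = false;
  int64_t value = 0;                               // used when is_const
  std::string name;                                // symbolic leaf, e.g. "blockDim.x"
  std::string op;                                  // non-empty when defined by an op
  std::vector<std::shared_ptr<ToyExtent>> inputs;  // operands of that op
};

// Conservative structural equality: false means "could not prove equal".
bool toyEqualDim(const ToyExtent* a, const ToyExtent* b) {
  if (a == b) {
    return true;
  }
  if (a->is_const && b->is_const) {
    return a->value == b->value;
  }
  if (a->op.empty() && b->op.empty()) {
    return !a->name.empty() && a->name == b->name;
  }
  if (!a->op.empty() && a->op == b->op && a->inputs.size() == b->inputs.size()) {
    for (size_t i = 0; i < a->inputs.size(); ++i) {
      if (!toyEqualDim(a->inputs[i].get(), b->inputs[i].get())) {
        return false;
      }
    }
    return true;
  }
  return false;  // give up rather than claim equality
}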
- kir::Val* get(ParallelType pt) const; + Val* get(ParallelType pt) const; //! True if the dimension of a ParallelType is known to be exact bool isExact(ParallelType pt) const; @@ -29,7 +29,7 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { std::string toString() const; //! Symbolically analyze if two extent vals are equal - static bool equalDim(kir::Val* dim1, kir::Val* dim2); + static bool equalDim(Val* dim1, Val* dim2); private: //! Register the extent of an IterDomain if its constant @@ -54,7 +54,7 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { private: //! Maps from parallel types to dimensions, which are constant if //! a unique value is found. - std::unordered_map dim_map_; + std::unordered_map dim_map_; //! Set of parallel types whose dimensions are identified to be //! exactly the same as extents of mapped domains. std::unordered_set exact_types_; diff --git a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h index 0bf8ae39277..3bfb32d38bc 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h +++ b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 11c27cffec2..94dad076db8 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -33,25 +34,18 @@ constexpr auto kNumBinaryFloatOps = 3; constexpr auto kNumBinaryComparisonOps = 12; constexpr auto kNumBinaryCastOps = 14; -constexpr auto kNumBinaryOpsWithAlpha = 4; +constexpr auto kNumBinaryOpsWithAlpha = 6; constexpr auto kNumLerpOps = 2; constexpr auto kNumLayernormFwd = 2; constexpr auto kNumBatchnormFwd = 3; constexpr auto kNumInstancenormFwd = 1; constexpr auto kNumSumToSize = 2; constexpr auto kNumAutocastOps = 2; -// constexpr auto kNumViewSize = 2; +constexpr auto kNumAliasDimOps = 2; +constexpr auto kNumViewOps = 2; namespace { -std::vector getTensorSizes(TensorTypePtr const& tensor_type) { - TORCH_INTERNAL_ASSERT(tensor_type != nullptr, "Input must be a Tensor."); - auto optional_sizes = tensor_type->sizes().concrete_sizes(); - TORCH_INTERNAL_ASSERT( - optional_sizes.has_value(), "Missing size information for the tensor."); - return optional_sizes.value(); -} - #define REGISTER_PARSE_RULE(op, func_body, ...) \ registerParseRule( \ op, \ @@ -59,7 +53,8 @@ std::vector getTensorSizes(TensorTypePtr const& tensor_type) { -> void func_body, \ __VA_ARGS__) -const auto& sizeAttr = Symbol::attr("profiled_size"); +const auto& reductionSizeAttr = Symbol::attr("profiled_reduction_size"); +const auto& viewSizeAttr = Symbol::attr("profiled_view_size"); const auto& intListAttr = Symbol::attr("profiled_int_list"); const auto& intAttr = Symbol::attr("profiled_int"); const auto& boolListAttr = Symbol::attr("profiled_bool_list"); @@ -283,8 +278,9 @@ class ValueHolder { if (iter_val != vals_.end()) { return iter_val->second; } - // patching scalar value, because memory format doesn't carry real meaning. - if (!is_tensor_view_) { + // patching scalar (tensor), memory format doesn't carry meaning and should + // just return the value as-is. 
+ if (!is_tensor_view_ || rank() == 0) { return std::get<1>(getEntry()); } MemoryFormat format_s; @@ -505,7 +501,7 @@ class IrParser { "Failure when register value: ", *(val->node()), " with type: ", - val->type()); + val->type()->repr_str()); MemoryFormat format; Val* operand = nullptr; std::tie(format, operand) = value_map_[val->unique()].getEntry(); @@ -523,7 +519,6 @@ class IrParser { (opt_dtype.value() == DataType::Half || opt_dtype.value() == DataType::BFloat16)) { Val* promoted_val = castOp(DataType::Float, operand); - // value_map_.emplace(val->unique(), ValueHolder(promoted_val, format)); value_map_[val->unique()] = ValueHolder(promoted_val, format); } } @@ -688,7 +683,9 @@ class IrParser { "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", - "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"}; + "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", + "aten::rsub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::rsub(Tensor self, Scalar other, Scalar alpha) -> Tensor"}; for (auto signature : BinaryOpWithAlpha) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( @@ -704,6 +701,10 @@ class IrParser { BinaryOpType::Add, static_cast(&add_alpha))}, {aten::sub, + std::make_pair( + BinaryOpType::Sub, + static_cast(&sub_alpha))}, + {aten::rsub, std::make_pair( BinaryOpType::Sub, static_cast(&sub_alpha))}}); @@ -723,10 +724,12 @@ class IrParser { auto out = alpha->isOneInt() ? binaryOp( op_mapping[node->kind()].first, - lhs, - rhs, + node->kind() == aten::rsub ? rhs : lhs, + node->kind() == aten::rsub ? lhs : rhs, TypePromotion::default_op_config) - : op_mapping[node->kind()].second(lhs, rhs, alpha); + : (node->kind() == aten::rsub + ? op_mapping[node->kind()].second(rhs, lhs, alpha) + : op_mapping[node->kind()].second(lhs, rhs, alpha)); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, @@ -1101,10 +1104,10 @@ class IrParser { list_val.pop_front(); Val* low = value_map.count(node->inputs()[1]->unique()) != 0 ? *value_map[node->inputs()[1]->unique()] - : new Double(std::numeric_limits::min()); + : IrBuilder::create(std::numeric_limits::min()); Val* high = value_map.count(node->inputs()[2]->unique()) != 0 ? 
*value_map[node->inputs()[2]->unique()] - : new Double(std::numeric_limits::max()); + : IrBuilder::create(std::numeric_limits::max()); auto out = clamp(operand, low, high); value_map.emplace(node->output()->unique(), out); @@ -1340,7 +1343,7 @@ class IrParser { running_mean = value_map[node->input(3)->unique()]->as(); TORCH_INTERNAL_ASSERT( - fusion->hasInput(running_mean), + running_mean->isFusionInput(), "IO_tensor `instance_norm::running_mean` can only be input tensor to fusion"); } @@ -1350,7 +1353,7 @@ class IrParser { running_var = value_map[node->input(4)->unique()]->as(); TORCH_INTERNAL_ASSERT( - fusion->hasInput(running_var), + running_var->isFusionInput(), "IO_tensor `instance_norm::running_var` can only be input tensor to fusion"); } @@ -1364,7 +1367,7 @@ class IrParser { Val* momentum_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto momentum = constant_as(node->input(6))) { - momentum_ptr = new Double(momentum.value()); + momentum_ptr = IrBuilder::create(momentum.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) momentum_ptr = value_map[node->input(6)->unique()]; @@ -1373,7 +1376,7 @@ class IrParser { Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(7))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; @@ -1458,7 +1461,7 @@ class IrParser { Val* momentum_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto momentum = constant_as(node->input(6))) { - momentum_ptr = new Double(momentum.value()); + momentum_ptr = IrBuilder::create(momentum.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) momentum_ptr = value_map[node->input(6)->unique()]; @@ -1467,7 +1470,7 @@ class IrParser { Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(7))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; @@ -1586,7 +1589,7 @@ class IrParser { Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(9))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; @@ -1704,7 +1707,7 @@ class IrParser { Val* eps_ptr = nullptr; if (auto eps = constant_as(node->input(4))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { eps_ptr = value_map[node->input(4)->unique()]; } @@ -2032,7 +2035,7 @@ class IrParser { keepdim.has_value(), "aten::mean cannot be fused with dynamic keepdim"); auto o_sum = sum(self, dims, keepdim.value()); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(1); for (auto axis : dims) { if (axis < 0) { axis += int(self->nDims()); @@ -2347,6 +2350,31 @@ class IrParser { nullptr); } + { + auto ptr_op = getOperatorForLiteral( + "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, + value_map[node->inputs()[0]->unique()], + 
value_map[node->inputs()[1]->unique()]); + auto grad_out = list_val.front(); + list_val.pop_front(); + auto self = list_val.front(); + list_val.pop_front(); + + auto grad_in = tanh_backward(grad_out, self); + value_map.emplace( + node->output()->unique(), ValueHolder(grad_in, format)); + }, + nullptr, + nullptr); + } + { auto ptr_op = getOperatorForLiteral( "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"); @@ -2392,37 +2420,111 @@ class IrParser { }); } - /* - // TODO: Enable view in parser by detecting non-alias view operation { - std::array View = { - "aten::view(Tensor(a) self, int[] size) -> Tensor(a)", - "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"}; - for (auto signature : View) { + std::array ViewOps = { + "prim::reshape_copy(Tensor self, int[] shape) -> Tensor", + "prim::view_copy(Tensor self, int[] size) -> Tensor"}; + for (auto signature : ViewOps) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { auto self_value = node->inputs()[0]; - auto self = value_map[self_value->unique()]->as(); + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + MemoryFormat::Contiguous(), value_map[self_value->unique()]); + auto self = list_val.front()->as(); + list_val.pop_front(); auto self_type = self_value->type()->cast(); TORCH_INTERNAL_ASSERT(self_type != nullptr); auto self_sizes = getTensorSizes(self_type); - auto size_optional = - constant_as>(node->input(1)); + auto view_sizes = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( - size_optional.has_value(), "The size parameter is required."); + view_sizes.has_value(), "The size parameter is required."); - auto output = view(self, self_sizes, size_optional->vec()); + auto output = view(self, self_sizes, view_sizes->vec()); + value_map.emplace(node->output()->unique(), output); + }, + [](const Node* node) -> bool { + // Reject fusing node if view_sizes contains an inferred dimension + auto view_sizes = constant_as>(node->input(1)); + TORCH_INTERNAL_ASSERT( + view_sizes.has_value(), "The size parameter is required."); + for (auto axis_size : view_sizes->vec()) { + if (axis_size == -1) { + return false; + } + } + return true; + }, + nullptr); + } + } + + { + auto ptr_op = + getOperatorForLiteral("prim::squeeze_copy(Tensor self) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + auto self_value = node->inputs()[0]; + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + MemoryFormat::Contiguous(), value_map[self_value->unique()]); + auto self = list_val.front()->as(); + list_val.pop_front(); + + auto self_type = self_value->type()->cast(); + TORCH_INTERNAL_ASSERT(self_type != nullptr); + auto self_sizes = getTensorSizes(self_type); + + auto output = squeeze(self, self_sizes); + value_map.emplace(node->output()->unique(), output); + }, + nullptr, + nullptr); + } + + { + std::array AliasOpWithDim = { + "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor", + "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor"}; + for (auto signature : AliasOpWithDim) { + auto ptr_op = getOperatorForLiteral(signature); + REGISTER_PARSE_RULE( + ptr_op, + { + auto self_value = node->inputs()[0]; + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + MemoryFormat::Contiguous(), + value_map[node->inputs()[0]->unique()]); + auto self = list_val.front()->as(); + list_val.pop_front(); + + auto dim_value = constant_as(node->input(1)); + 
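Editor note on the aten::rsub rules registered earlier in this parser hunk: they reuse the existing sub/sub_alpha machinery by swapping lhs and rhs, because rsub(self, other, alpha) computes other - alpha * self, the mirror image of sub's self - alpha * other. A scalar sketch of that identity (illustrative, not parser code):

double sub_alpha_scalar(double self, double other, double alpha) {
  return self - alpha * other;  // aten::sub semantics
}

double rsub_scalar(double self, double other, double alpha) {
  // aten::rsub(self, other, alpha) == other - alpha * self,
  // i.e. sub_alpha with the two operands exchanged.
  return sub_alpha_scalar(/*self=*/other, /*other=*/self, alpha);
}

Mapping rsub onto BinaryOpType::Sub this way avoids introducing a separate op type in the fusion IR.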
TORCH_INTERNAL_ASSERT(dim_value.has_value(), "dim is not valid"); + + TensorView* output = nullptr; + if (node->kind() == prim::unsqueeze_copy) { + output = unsqueeze(self, dim_value.value()); + } else { + auto self_type = self_value->type()->cast(); + TORCH_INTERNAL_ASSERT(self_type != nullptr); + auto self_sizes = getTensorSizes(self_type); + output = squeeze(self, self_sizes, dim_value.value()); + } value_map.emplace(node->output()->unique(), output); }, nullptr, nullptr); } } - */ } void processJitNode(const JitOp* node) { @@ -2456,9 +2558,9 @@ class IrParser { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { - cg_val = new Double(ival.value()); + cg_val = IrBuilder::create(ival.value()); } else { - cg_val = new Double(); + cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; @@ -2467,9 +2569,9 @@ class IrParser { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { - cg_val = new Int(ival.value()); + cg_val = IrBuilder::create(ival.value()); } else { - cg_val = new Int(); + cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; @@ -2478,9 +2580,9 @@ class IrParser { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { - cg_val = new Bool(ival.value()); + cg_val = IrBuilder::create(ival.value()); } else { - cg_val = new Bool(); + cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; @@ -2496,7 +2598,11 @@ class IrParser { // TODO: we don't support list type in codegen yet; // This is a WAR to allow axes of reduction to be passed as constant list; // We simply ignore conversion if the scalar value is a constant; - return toIValue(val).has_value(); + auto ivalue = toIValue(val); + TORCH_INTERNAL_ASSERT( + ivalue.has_value(), + "List[T] is not supported as an argument by NvFuser. 
Use a Constant List."); + return true; } return false; } @@ -2566,7 +2672,10 @@ class IrParser { tensor_type->undefined()); } - cg_val = new TensorView(tensor_type); + cg_val = IrBuilder::create(tensor_type); + if (is_cpu_scalar(*tensor_type)) { + cg_val->as()->setCpuScalar(true); + } value_map_.emplace(val->unique(), ValueHolder(cg_val, format)); return true; } @@ -2611,7 +2720,7 @@ ProfileIValueOp* insertProfileIValueOp( return pn; } -void profileSize(ProfilingRecord* pr, Node* node, size_t offset) { +void profileReductionSize(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { @@ -2631,12 +2740,14 @@ void profileSize(ProfilingRecord* pr, Node* node, size_t offset) { size_vec.clear(); } else { TORCH_INTERNAL_ASSERT( - false, "profileSize does not support data type: ", value.tagKind()); + false, + "profileReductionSize does not support data type: ", + value.tagKind()); } - if (!pn->hasAttribute(sizeAttr)) { - pn->is_(sizeAttr, size_vec); + if (!pn->hasAttribute(reductionSizeAttr)) { + pn->is_(reductionSizeAttr, size_vec); } else { - auto profiled_ints = pn->is(sizeAttr); + auto profiled_ints = pn->is(reductionSizeAttr); TORCH_INTERNAL_ASSERT( profiled_ints.size() == size_vec.size() && std::equal( @@ -2648,6 +2759,39 @@ void profileSize(ProfilingRecord* pr, Node* node, size_t offset) { pn->setCallback(ivalue_profiler); } +void profileViewSize(ProfilingRecord* pr, Node* node, size_t offset) { + auto pn = insertProfileIValueOp(node, offset, pr); + + const auto ivalue_profiler = [pr, pn](Stack& stack) { + std::lock_guard lock(pr->mutex_); + + // TODO: we don't care about merging multiple profiling runs as we don't + // support it at all; + int64_t frame_id = 0; + pop(stack, frame_id); + IValue value; + pop(stack, value); + TORCH_INTERNAL_ASSERT( + value.isIntList(), "profiling seeing the wrong data type"); + if (!pn->hasAttribute(viewSizeAttr)) { + pn->is_(viewSizeAttr, value.toIntVector()); + } else { + auto profiled_ints = pn->is(viewSizeAttr); + auto input_ints = value.toIntList(); + TORCH_INTERNAL_ASSERT( + profiled_ints.size() == input_ints.size() && + std::equal( + profiled_ints.begin(), + profiled_ints.end(), + input_ints.begin()), + "profiling ivalue doesn't support merge"); + } + push(stack, value); + }; + + pn->setCallback(ivalue_profiler); +} + void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); @@ -2943,7 +3087,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { // argument 1: reduction sizes; case 1: // TODO(profile_size): double check optional[size]? 
- profileSize(pr, node, offset); + profileReductionSize(pr, node, offset); break; default: return false; @@ -2951,28 +3095,52 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { return true; } - /* - // TODO: Enable view in parser by detecting non-alias view operation - static auto view_schema = + static auto reshape_schema = + getOperatorForLiteral("aten::reshape(Tensor self, int[] shape) -> Tensor") + ->schema(); + static auto reshape_copy_schema = getOperatorForLiteral( - "aten::view(Tensor(a) self, int[] size) -> Tensor(a)") + "prim::reshape_copy(Tensor self, int[] shape) -> Tensor") ->schema(); - static auto reshape_schema = + static auto view_schema = + getOperatorForLiteral("aten::view(Tensor self, int[] size) -> Tensor") + ->schema(); + static auto view_copy_schema = getOperatorForLiteral( - "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)") + "prim::view_copy(Tensor self, int[] size) -> Tensor") ->schema(); - if (node->matches(view_schema) || node->matches(reshape_schema)) { + if (node->matches(reshape_schema) || node->matches(reshape_copy_schema) || + node->matches(view_schema) || node->matches(view_copy_schema)) { switch (offset) { // argument 1: new tensor size; case 1: - profileSize(pr, node, offset); + profileViewSize(pr, node, offset); + break; + default: + return false; + } + return true; + } + + static auto squeeze_dim_schema = + getOperatorForLiteral( + "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor") + ->schema(); + static auto unsqueeze_schema = + getOperatorForLiteral( + "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor") + ->schema(); + if (node->matches(squeeze_dim_schema) || node->matches(unsqueeze_schema)) { + switch (offset) { + // argument 1: unsqueeze dim; + case 1: + profileInt(pr, node, offset); break; default: return false; } return true; } - */ static auto batch_norm_impl_index_schema = getOperatorForLiteral( diff --git a/torch/csrc/jit/codegen/cuda/parser.h b/torch/csrc/jit/codegen/cuda/parser.h index 4b2fcf50f99..6d52b325042 100644 --- a/torch/csrc/jit/codegen/cuda/parser.h +++ b/torch/csrc/jit/codegen/cuda/parser.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/partial_split_map.cpp b/torch/csrc/jit/codegen/cuda/partial_split_map.cpp index e7b6db4d165..e320e8ee373 100644 --- a/torch/csrc/jit/codegen/cuda/partial_split_map.cpp +++ b/torch/csrc/jit/codegen/cuda/partial_split_map.cpp @@ -12,7 +12,7 @@ void PartialSplitMap::build(Fusion* fusion) { auto used_vals = ir_utils::allTvs(fusion); for (auto tv : ir_utils::filterByType(used_vals)) { - auto exprs = ExprSort::getExprs( + auto exprs = StmtSort::getExprs( fusion, {tv->domain()->domain().begin(), tv->domain()->domain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // Only needs to check root domains as partial split is only @@ -24,18 +24,10 @@ void PartialSplitMap::build(Fusion* fusion) { continue; } auto root_domain = split->in(); - auto kir_root_domain = - gpu_lower->lowerValue(split->in())->as(); auto start_offset = split->startOffset(); start_offset_map_.insert({root_domain, start_offset}); - kir_start_offset_map_.insert( - {kir_root_domain, - gpu_lower->lowerValue(start_offset)->as()}); auto stop_offset = split->stopOffset(); stop_offset_map_.insert({root_domain, stop_offset}); - kir_stop_offset_map_.insert( - {kir_root_domain, - gpu_lower->lowerValue(stop_offset)->as()}); } } } @@ -49,15 +41,6 @@ Val* PartialSplitMap::getStartOffset(IterDomain* root_domain) const { } } -kir::Val* 
PartialSplitMap::getStartOffset(kir::IterDomain* root_domain) const { - auto it = kir_start_offset_map_.find(root_domain); - if (it == kir_start_offset_map_.end()) { - return nullptr; - } else { - return it->second; - } -} - Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const { auto it = stop_offset_map_.find(root_domain); if (it == stop_offset_map_.end()) { @@ -67,15 +50,6 @@ Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const { } } -kir::Val* PartialSplitMap::getStopOffset(kir::IterDomain* root_domain) const { - auto it = kir_stop_offset_map_.find(root_domain); - if (it == kir_stop_offset_map_.end()) { - return nullptr; - } else { - return it->second; - } -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/partial_split_map.h b/torch/csrc/jit/codegen/cuda/partial_split_map.h index be432bd5a16..8ec489915b7 100644 --- a/torch/csrc/jit/codegen/cuda/partial_split_map.h +++ b/torch/csrc/jit/codegen/cuda/partial_split_map.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -20,15 +20,11 @@ class TORCH_CUDA_CU_API PartialSplitMap { void build(Fusion* fusion); Val* getStartOffset(IterDomain* root_domain) const; - kir::Val* getStartOffset(kir::IterDomain* root_domain) const; Val* getStopOffset(IterDomain* root_domain) const; - kir::Val* getStopOffset(kir::IterDomain* root_domain) const; private: std::unordered_map start_offset_map_; - std::unordered_map kir_start_offset_map_; std::unordered_map stop_offset_map_; - std::unordered_map kir_stop_offset_map_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp index 004c836ec4e..91d68494fd4 100644 --- a/torch/csrc/jit/codegen/cuda/partition.cpp +++ b/torch/csrc/jit/codegen/cuda/partition.cpp @@ -5,12 +5,15 @@ #include #include #include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { +const c10::DeviceIndex INVALID_INDEX = -2; + namespace { bool hasNonElementWiseOperation(const Node* node) { @@ -38,26 +41,61 @@ static c10::optional getDevice(const Value* value) { // not tensor type, return false as the op is not outputing scalar. return c10::nullopt; } - return value->type()->expectRef().device(); + auto tensor_type = value->type()->expectRef(); + // special case for scalar tensor: return c10::nullopt instead of cpu device. + // this allows us to fuse scalar cpu tensor with cuda tensor, while avoid + // merging ops with pure scalar cpu tensors. 
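Editor note on the device handling added to partition.cpp just below: getDevice now folds the devices seen on a node's inputs and outputs into a single value, where CPU scalar tensors contribute no device (so a CUDA fusion can still absorb them), agreement keeps the device, and any disagreement is flagged via INVALID_INDEX so the node is not fused. A toy version of that folding rule, with strings standing in for c10::Device (illustrative only):

#include <optional>
#include <string>
#include <vector>

std::optional<std::string> mergeNodeDevices(
    const std::vector<std::optional<std::string>>& devices) {
  std::optional<std::string> ret;
  for (const auto& d : devices) {
    if (!d.has_value()) {
      continue;  // e.g. a CPU scalar tensor: contributes no device
    }
    if (!ret.has_value()) {
      ret = d;  // first device seen
    } else if (*ret != *d) {
      return std::string("INVALID");  // conflicting devices: reject fusion
    }
  }
  return ret;  // nullopt when nothing carried a device
}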
+ if (is_cpu_scalar(tensor_type)) { + return c10::nullopt; + } + return tensor_type.device(); } static c10::optional getDevice(const Node* node) { - auto outputs = node->outputs(); - for (auto output : outputs) { - auto device = getDevice(output); + c10::optional ret = c10::nullopt; + auto merge_devices = [&ret](const c10::optional& device) { if (device.has_value()) { - return device; + if (ret.has_value()) { + if (ret.value() != device.value()) { + // invalidate device to reflect conflicts + ret->set_index(INVALID_INDEX); + // return false to indicate early termination + return false; + } else { + // same device, do nothing + return true; + } + } else { + // initialize return device + ret = device.value(); + return true; + } + } + // no device information, do nothing + return true; + }; + for (auto val : node->inputs()) { + if (!merge_devices(getDevice(val))) { + return ret; } } - return c10::nullopt; + for (auto val : node->outputs()) { + if (!merge_devices(getDevice(val))) { + return ret; + } + } + return ret; } static bool isFusibleDevice(const Node* node, const c10::Device device) { - for (auto value : node->outputs()) { - auto output_device = getDevice(value); - if (output_device.has_value() && output_device.value() != device) { - return false; - } + TORCH_INTERNAL_ASSERT( + device.index() != INVALID_INDEX, "fusible device needs to be validate"); + auto opt_device = getDevice(node); + // we can be more relaxed here as we known that this function tries to merge + // node into an existing `device` + if (opt_device.has_value() && + (opt_device->index() == INVALID_INDEX || opt_device != device)) { + return false; } return true; } @@ -65,10 +103,12 @@ static bool isFusibleDevice(const Node* node, const c10::Device device) { // TODO: we need to check input type when we handle `to()` static bool isFusibleDevice(const Node* node) { auto device = getDevice(node); + // be conservative and only fuse cuda operations, this avoids us initializing + // operations that produces cpu scalar outputs if (!device.has_value()) { - return true; + return false; } - return device->is_cuda() && + return device->index() != INVALID_INDEX && device->is_cuda() && (at::cuda::getDeviceProperties(device->index())->major >= 7 || !hasNonElementWiseOperation(node)); } @@ -400,7 +440,7 @@ bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) { bool fused = false; // TODO: lift the restriction of not fusing producer containing reduction when // we have proper scheduling. - if (isFusibleCudaFusionGroup(node)) { + if (isFusibleNode(node)) { // ensure if the node has a designated device, it's on the same device with // fusion. 
// TODO: is there a danger of us fusing operations that's supposed to be on @@ -408,7 +448,6 @@ bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) { auto device = getDevice(fusion); fused = (!device.has_value() || isFusibleDevice(node, device.value())); } - return fused; } diff --git a/torch/csrc/jit/codegen/cuda/partition.h b/torch/csrc/jit/codegen/cuda/partition.h index 0d8baca4700..b295cb582e5 100644 --- a/torch/csrc/jit/codegen/cuda/partition.h +++ b/torch/csrc/jit/codegen/cuda/partition.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include /* diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index b501a6133f6..6575b374423 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include @@ -20,27 +18,23 @@ namespace cuda { namespace { -bool isTensorIndexOp(kir::Expr* expr) { +bool isTensorIndexOp(Expr* expr) { const auto& outputs = expr->outputs(); return outputs.size() >= 1 && outputs[0]->isA(); } -bool isOutputLocal(const kir::Expr* expr) { +bool isOutputLocal(const Expr* expr) { return std::all_of( - expr->outputs().begin(), - expr->outputs().end(), - [](const kir::Val* output) { - return !output->isA() || - output->as()->memoryType() == MemoryType::Local; + expr->outputs().begin(), expr->outputs().end(), [](const Val* output) { + return !output->isA() || + output->as()->getMemoryType() == MemoryType::Local; }); } } // namespace -bool ParallelizedDomainPredicate::PredicateInfo::addDomain( - kir::IterDomain* id) { - const auto gpu_lower = GpuLower::current(); - auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(id); +bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) { + auto concrete_id = GpuLower::current()->caIndexMap().getConcreteMappedID(id); if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) { ids_.push_back(concrete_id); return true; @@ -49,21 +43,19 @@ bool ParallelizedDomainPredicate::PredicateInfo::addDomain( } } -kir::Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - - kir::Bool* pred = nullptr; +Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { + Bool* pred = nullptr; - auto index = - ir_builder.create(stringifyThread(pt_), DataType::Int); + auto index = SimplifyingIrBuilder::create( + stringifyThread(pt_), DataType::Int); for (const auto& pred_id : ids()) { // Just sanity check that pred_id is concrete TORCH_INTERNAL_ASSERT( - pred_id == gpu_lower->caIndexMap().getConcreteMappedID(pred_id)); - auto new_pred = ir_builder.ltExpr(index, pred_id->extent()); - pred = ir_builder.andExpr(pred, new_pred)->as(); + pred_id == + GpuLower::current()->caIndexMap().getConcreteMappedID(pred_id)); + auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent()); + pred = SimplifyingIrBuilder::andExpr(pred, new_pred)->as(); } return pred; @@ -74,16 +66,12 @@ namespace { std::unordered_set getNonUnswitchedRootDomains( const std::vector& loops, size_t unswitched_loop_index) { - const auto gpu_lower = GpuLower::current(); - std::vector non_unswited_leaf_domains; std::transform( loops.begin(), loops.begin() + unswitched_loop_index, std::back_inserter(non_unswited_leaf_domains), - [&](kir::ForLoop* loop) { - return 
gpu_lower->caIndexMap().toFusion(loop->iter_domain()); - }); + [&](kir::ForLoop* loop) { return loop->iter_domain(); }); auto non_unswitched_inputs = IterVisitor::getInputsTo(non_unswited_leaf_domains); @@ -100,26 +88,23 @@ std::unordered_set getNonUnswitchedRootDomains( non_unswitched_concrete_root_domains, non_unswitched_concrete_root_domains.end()), [&](auto root_dom) { - return gpu_lower->caIndexMap().getConcreteMappedID(root_dom); + return GpuLower::current()->caIndexMap().getConcreteMappedID(root_dom); }); return non_unswitched_concrete_root_domains; } bool isFullyUnswitched( - kir::IterDomain* loop_id, + IterDomain* loop_id, const std::unordered_set& non_unswitched_root_domains) { - const auto gpu_lower = GpuLower::current(); - - auto root_vals = - IterVisitor::getInputsTo({gpu_lower->caIndexMap().toFusion(loop_id)}); + auto root_vals = IterVisitor::getInputsTo({loop_id}); auto root_domains = ir_utils::filterByType(root_vals); return std::none_of( root_domains.begin(), root_domains.end(), [&](auto root_dom) { auto concrete_root_dom = - gpu_lower->caIndexMap().getConcreteMappedID(root_dom); + GpuLower::current()->caIndexMap().getConcreteMappedID(root_dom); return non_unswitched_root_domains.count(concrete_root_dom) > 0; }); } @@ -131,12 +116,10 @@ std::unordered_map< ParallelizedDomainPredicate::PredicateInfo, TypeHash> ParallelizedDomainPredicate::getPredicateMap( - const kir::Expr* expr, + const Expr* expr, const std::vector& loops, kir::ForLoop* unswitched_loop) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto output_tvs = ir_utils::getTvs(expr->outputs()); if (output_tvs.empty()) { @@ -167,7 +150,7 @@ ParallelizedDomainPredicate::getPredicateMap( } auto loop_id = loop->iter_domain(); - auto loop_ptype = loop_id->parallelType(); + auto loop_ptype = loop_id->getParallelType(); // Not necessary to add a predicate if the paralle type is exact if (!isParallelTypeThread(loop_ptype) || @@ -193,7 +176,7 @@ ParallelizedDomainPredicate::getPredicateMap( continue; } - kir::IterDomain* tv_id = *it; + IterDomain* tv_id = *it; // If the corresponding domain is a broadcast, it's not really used. if (tv_id->isBroadcast()) { @@ -203,9 +186,9 @@ ParallelizedDomainPredicate::getPredicateMap( // If it's a root domain, it should be covered by the root // predicates, so no extra predicate is required. 
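Editor note on the parallelized-domain predicates built in this file: for each non-exact thread parallel type an expression uses, the code ANDs together "thread index < mapped concrete extent" terms (one per tracked concrete domain). In the generated kernel the guard ends up with roughly the shape of the following illustrative CUDA device snippet; this is not actual nvfuser output, and D0/D1 stand in for the mapped extents:

// Illustrative only: guard shape for predicates on threadIdx.x and threadIdx.y.
__global__ void predicated_copy(float* out, const float* in, int D0, int D1) {
  if (threadIdx.x < D0 && threadIdx.y < D1) {
    // Only threads inside the mapped domain extents execute the body.
    out[threadIdx.y * D0 + threadIdx.x] = in[threadIdx.y * D0 + threadIdx.x];
  }
}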
if (std::find( - tv->domain()->rootDomain().begin(), - tv->domain()->rootDomain().end(), - tv_id) != tv->domain()->rootDomain().end()) { + tv->domain()->getRootDomain().begin(), + tv->domain()->getRootDomain().end(), + tv_id) != tv->domain()->getRootDomain().end()) { continue; } @@ -218,26 +201,24 @@ ParallelizedDomainPredicate::getPredicateMap( return map; } -kir::Bool* ParallelizedDomainPredicate::getPredicate( - const kir::Expr* expr, +Bool* ParallelizedDomainPredicate::getPredicate( + const Expr* expr, const std::vector& loops) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - auto pred_map = getPredicateMap(expr, loops); - kir::Val* pred = ir_builder.trueVal(); + Val* pred = GpuLower::current()->kernel()->trueVal(); for (auto pt : kParallelTypeThreads) { auto pred_info_it = pred_map.find(pt); if (pred_info_it != pred_map.end()) { const auto& pred_info = pred_info_it->second; auto tid_pred = pred_info.getPredicate(); - pred = ir_builder.andExpr(pred, tid_pred); + pred = SimplifyingIrBuilder::andExpr(pred, tid_pred); } } if (pred) { - return pred->as(); + return pred->as(); } else { return nullptr; } @@ -256,61 +237,55 @@ UnswitchPredicateKey::UnswitchPredicateKey() // concrete domains are used to uniquely collect all necessary // unswitch predicates. UnswitchPredicateKey::UnswitchPredicateKey( - IterDomain* predicated_concrete_id, - const ReferenceTensor& reference) + IterDomain* predicated_consumer_id, + TensorView* consumer_tv, + IterDomain* predicated_concrete_id) : predicated_concrete_id_(predicated_concrete_id) { // Initialize the parallelized domain map for (auto pt : kParallelTypeThreads) { parallel_concrete_ids_.insert({pt, nullptr}); } - // The id parameter is a concrete domain. Needs to find the - // corresponding reference domain to find leaf domains that are - // parallelized. 
- IterDomain* predicated_ref_id = - reference.concrete_to_id.at(predicated_concrete_id_); - TensorDomain* ref_td = reference.domain; - - std::vector all_parallelized_ref_leaf_ids; + std::vector all_parallelized_consumer_leaf_ids; std::copy_if( - ref_td->domain().begin(), - ref_td->domain().end(), - std::back_inserter(all_parallelized_ref_leaf_ids), + consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end(), + std::back_inserter(all_parallelized_consumer_leaf_ids), [](IterDomain* x) { return isParallelTypeThread(x->getParallelType()); }); - // If the reference is not parallelized at all, no need to + // If the consumer domais are not parallelized at all, no need to // differentiate keys based on how the predicated id is parallelized - if (all_parallelized_ref_leaf_ids.empty()) { + if (all_parallelized_consumer_leaf_ids.empty()) { return; } - // All domains that are parallelized descendants of predicated_ref_id - auto all_parallelized_ref_ids = DependencyCheck::getAllValsBetween( - {predicated_ref_id}, all_parallelized_ref_leaf_ids); + // All domains that are parallelized descendants of predicated_consumer_id + auto all_parallelized_consumer_ids = DependencyCheck::getAllValsBetween( + {predicated_consumer_id}, all_parallelized_consumer_leaf_ids); // Just pick leaf domains - std::vector parallelized_ref_leaf_ids; + std::vector parallelized_consumer_leaf_ids; std::copy_if( - ref_td->domain().begin(), - ref_td->domain().end(), - std::back_inserter(parallelized_ref_leaf_ids), + consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end(), + std::back_inserter(parallelized_consumer_leaf_ids), [&](IterDomain* x) { return std::find( - all_parallelized_ref_ids.begin(), - all_parallelized_ref_ids.end(), - x) != all_parallelized_ref_ids.end(); + all_parallelized_consumer_ids.begin(), + all_parallelized_consumer_ids.end(), + x) != all_parallelized_consumer_ids.end(); }); - if (parallelized_ref_leaf_ids.empty()) { - // None of the parallelized leaf domains are derived from predicated_ref_id + if (parallelized_consumer_leaf_ids.empty()) { + // None of the parallelized leaf domains are derived from + // predicated_consumer_id return; } // Find the corresponding concrete id for each parallel type - for (auto ref_leaf : parallelized_ref_leaf_ids) { - auto pt = ref_leaf->getParallelType(); - auto it = reference.id_to_concrete.find(ref_leaf); - TORCH_INTERNAL_ASSERT(it != reference.id_to_concrete.end()); - auto concrete_leaf = it->second; + for (auto consumer_leaf : parallelized_consumer_leaf_ids) { + auto pt = consumer_leaf->getParallelType(); + auto concrete_leaf = + GpuLower::current()->caIndexMap().getConcreteMappedID(consumer_leaf); parallel_concrete_ids_.at(pt) = concrete_leaf; } } @@ -344,19 +319,18 @@ std::size_t UnswitchPredicateKeyHash::operator()( return h; }; -kir::Bool* PredicateCompute::getInlinePredicate( - const kir::Expr* expr, +Bool* PredicateCompute::getInlinePredicate( + const Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, PredicateType pred_type) { FUSER_PERF_SCOPE("GpuLower::Lower::getInlinePredicate"); const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); // If outputs are registers, no need to predicate for threads if (isOutputLocal(expr)) { - thread_pred = ir_builder.trueVal(); + thread_pred = gpu_lower->kernel()->trueVal(); } if (loops.empty()) { @@ -364,8 +338,8 @@ kir::Bool* PredicateCompute::getInlinePredicate( return thread_pred; } - auto out_tv = 
ir_utils::getTVOutput(expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output"); + auto out_tv = ir_utils::getTvOutput(expr); + TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); if (gpu_lower->predicateElimination().canOmitPredicate(expr)) { return thread_pred; @@ -376,7 +350,7 @@ kir::Bool* PredicateCompute::getInlinePredicate( out_tv, loops, nullptr, pred_type == PredicateType::Padding) .first; - std::vector preds; + std::vector preds; // When pred_type is ReductionWrite, filter out predicates for // reduction axes. For blockReduce, this is necessary when reduction @@ -388,7 +362,7 @@ kir::Bool* PredicateCompute::getInlinePredicate( bool non_zero_start_found = false; for (const auto& pred_info : pred_info_vec) { if (pred_type == PredicateType::ReductionWrite) { - const auto& consumer_ids = pred_info.consumerIds(); + const auto& consumer_ids = pred_info.rootIds(); bool pred_for_reduction_axis = false; for (auto consumer_id : consumer_ids) { if (consumer_id->isReduction()) { @@ -404,21 +378,15 @@ kir::Bool* PredicateCompute::getInlinePredicate( continue; } } - for (auto pred : pred_info.startPredicates()) { - TORCH_INTERNAL_ASSERT(pred != nullptr); - preds.push_back(pred); - } - for (auto pred : pred_info.stopPredicates()) { - TORCH_INTERNAL_ASSERT(pred != nullptr); - preds.push_back(pred); - } + preds.push_back(pred_info.startPredicate()); + preds.push_back(pred_info.stopPredicate()); } // When generating a predicate for blockReduce writes and not for // gridReduce, if all reduction axes start with zero, we can just // use the same predicate for reads. nullptr is returned then. if (pred_type == PredicateType::ReductionWrite && !non_zero_start_found && - !out_tv->fuserTv()->domain()->hasGridReduction()) { + !out_tv->domain()->hasGridReduction()) { return nullptr; } @@ -433,35 +401,33 @@ kir::Bool* PredicateCompute::getInlinePredicate( } if (preds.empty()) { - return ir_builder.trueVal(); + return GpuLower::current()->kernel()->trueVal(); } - kir::Val* cond = preds[0]; + Val* cond = preds[0]; for (const auto i : c10::irange(1, preds.size())) { - cond = ir_builder.andExpr(cond, preds[i]); + cond = SimplifyingIrBuilder::andExpr(cond, preds[i]); } - return cond->as(); + return cond->as(); } -kir::Bool* UnswitchPredicate::get( +Bool* UnswitchPredicate::get( const std::vector& outer_loops, kir::ForLoop* unrolled_loop) { FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::get"); - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - UnswitchPredicate up(outer_loops, unrolled_loop); - kir::Val* unswitch_pred = ir_builder.trueVal(); + Val* unswitch_pred = GpuLower::current()->kernel()->trueVal(); for (auto pred : up.predicates_) { - unswitch_pred = ir_builder.andExpr(unswitch_pred, pred); + unswitch_pred = SimplifyingIrBuilder::andExpr(unswitch_pred, pred); } - return unswitch_pred->as(); + return unswitch_pred->as(); } -void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { +void UnswitchPredicate::predicateOn(Expr* tv_expr) { FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::predicateOn"); if (for_loops_.empty()) { @@ -469,14 +435,12 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { } const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr)) { return; } - auto out_tv = ir_utils::getTVOutput(tv_expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output"); + auto out_tv = 
ir_utils::getTvOutput(tv_expr); + TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); auto ref_pred_info = Index::getReferenceRootPredicates( out_tv, for_loops_, unrolled_loop_, false); @@ -491,10 +455,8 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { // predicates are generated in the finalize function. for (const auto& pred_info : ref_pred_info.first) { - if (pred_info.startPredicates().empty() && - pred_info.stopPredicates().empty()) { - continue; - } + TORCH_INTERNAL_ASSERT(pred_info.startPredicate() != nullptr); + TORCH_INTERNAL_ASSERT(pred_info.stopPredicate() != nullptr); const auto& root_ids = pred_info.rootIds(); @@ -505,13 +467,14 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { bool first_key_set = false; for (auto root_id : root_ids) { - auto kir_root_id = gpu_lower->lowerValue(root_id)->as(); + auto concrete_root_id = + gpu_lower->caIndexMap().getConcreteMappedID(root_id); - if (kir_root_id->isBroadcast()) { + if (root_id->isBroadcast()) { continue; } - UnswitchPredicateKey key(root_id, reference); + UnswitchPredicateKey key(root_id, out_tv, concrete_root_id); auto inserted = predicated_keys_.insert(key).second; add_pred = add_pred || inserted; @@ -573,14 +536,14 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { // start and stop offsets. if (merged_pred_it != pending_predicates_.end()) { mergeUnswitchPredicateOffsets( - pred_info.startPredicates(), - pred_info.startOffsets(), + pred_info.startPredicate(), + pred_info.startOffset(), merged_pred_it->start, true); mergeUnswitchPredicateOffsets( - pred_info.stopPredicates(), - pred_info.stopOffsets(), + pred_info.stopPredicate(), + pred_info.stopOffset(), merged_pred_it->stop, false); } @@ -613,7 +576,7 @@ void UnswitchPredicate::openLoop(kir::ForLoop* fl) { for_loops_.push_back(fl); for (auto expr : fl->body().exprs()) { - if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) { + if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) { predicateOn(expr); } else if (auto ite = dynamic_cast(expr)) { openIte(ite); @@ -630,7 +593,7 @@ void UnswitchPredicate::openIte(kir::IfThenElse* ite) { // only expand the ite thenBody for (auto expr : ite->thenBody().exprs()) { - if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) { + if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) { predicateOn(expr); } else if (auto ite = dynamic_cast(expr)) { openIte(ite); @@ -641,7 +604,6 @@ void UnswitchPredicate::openIte(kir::IfThenElse* ite) { } void UnswitchPredicate::finalize() { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); for (const auto& merged_pred : pending_predicates_) { const auto& start_info = merged_pred.start; if (start_info.static_pred) { @@ -661,12 +623,10 @@ void UnswitchPredicate::finalize() { } void UnswitchPredicate::mergeUnswitchPredicateOffsets( - const std::vector& predicates, - const std::vector& offsets, + Bool* predicate, + Val* offset, MergedPredicates::Info& merged_predicate_info, bool is_start) { - TORCH_INTERNAL_ASSERT(predicates.size() == offsets.size()); - auto is_more_restrictive = [&is_start](int64_t new_val, int64_t current_val) { if (is_start) { return new_val < current_val; @@ -675,25 +635,21 @@ void UnswitchPredicate::mergeUnswitchPredicateOffsets( } }; - for (const auto i : c10::irange(predicates.size())) { - auto pred = predicates.at(i); - auto offset = offsets.at(i); - auto offset_int = dynamic_cast(offset); - // If it's a static predicate, replace the current one if it's - // more restrictive. 
If it's dynamic, just adds it to the dynamic - // predicate list. - if (offset_int && offset_int->isConst()) { - auto offset_const = offset_int->value().value(); - auto& static_pred = merged_predicate_info.static_pred; - auto& static_offset = merged_predicate_info.static_offset; - if (static_pred == nullptr || - is_more_restrictive(offset_const, static_offset)) { - static_pred = pred; - static_offset = offset_const; - } - } else { - merged_predicate_info.dynamic_preds.push_back(pred); + auto offset_int = dynamic_cast(offset); + // If it's a static predicate, replace the current one if it's + // more restrictive. If it's dynamic, just adds it to the dynamic + // predicate list. + if (offset_int && offset_int->isConst()) { + auto offset_const = offset_int->value().value(); + auto& static_pred = merged_predicate_info.static_pred; + auto& static_offset = merged_predicate_info.static_offset; + if (static_pred == nullptr || + is_more_restrictive(offset_const, static_offset)) { + static_pred = predicate; + static_offset = offset_const; } + } else { + merged_predicate_info.dynamic_preds.push_back(predicate); } } diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.h b/torch/csrc/jit/codegen/cuda/predicate_compute.h index 989bffb3bd1..c6412671e43 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.h +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.h @@ -16,10 +16,10 @@ class PredicateCompute { // ignore_internal_syncthread_ops will prevent creation of predicates on // block/grid broadcast/reduce as these have syncthread calls within them // so all threads need to execute the function. - static kir::Bool* getInlinePredicate( - const kir::Expr* expr, + static Bool* getInlinePredicate( + const Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, PredicateType pred_type); }; @@ -40,31 +40,31 @@ class ParallelizedDomainPredicate { explicit PredicateInfo(ParallelType pt) : pt_(pt) {} //! Adds a domain that is parallized by the same paralell type - bool addDomain(kir::IterDomain* id); + bool addDomain(IterDomain* id); - const std::vector& ids() const { + const std::vector& ids() const { return ids_; } //! Generates a predicate Val from predicate information - kir::Bool* getPredicate() const; + Bool* getPredicate() const; private: ParallelType pt_; //! Domains parallelized by the same parallel type - std::vector ids_; + std::vector ids_; }; //! Returns a predicate Val for parallelied domains of an expression. - static kir::Bool* getPredicate( - const kir::Expr* expr, + static Bool* getPredicate( + const Expr* expr, const std::vector& loops); //! Returns predicate information for parallelied domains of an //! expression. static std::unordered_map getPredicateMap( - const kir::Expr* expr, + const Expr* expr, const std::vector& loops, kir::ForLoop* unswitched_loop = nullptr); }; @@ -80,8 +80,9 @@ class UnswitchPredicateKey { UnswitchPredicateKey(); UnswitchPredicateKey( - IterDomain* predicated_concrete_id, - const ReferenceTensor& reference); + IterDomain* predicated_consumer_id, + TensorView* consumer_tv, + IterDomain* predicated_concrete_id); bool operator==(const UnswitchPredicateKey& other) const { return predicated_concrete_id_ == other.predicated_concrete_id_ && @@ -121,7 +122,7 @@ struct UnswitchPredicateKeyHash { class TORCH_CUDA_CU_API UnswitchPredicate { public: - static kir::Bool* get( + static Bool* get( const std::vector& outer_loops, kir::ForLoop* unrolled_loop); @@ -132,11 +133,11 @@ class TORCH_CUDA_CU_API UnswitchPredicate { struct Info { //! 
Most restrictive static predicate. Nullptr if no static //! predicate found. - kir::Bool* static_pred = nullptr; + Bool* static_pred = nullptr; //! The offset value of static_pred int64_t static_offset = 0; //! List of dynamic predicates. - std::vector dynamic_preds; + std::vector dynamic_preds; }; UnswitchPredicateKey predicate_key; Info start; @@ -147,7 +148,7 @@ class TORCH_CUDA_CU_API UnswitchPredicate { std::vector outer_loops, kir::ForLoop* unrolled_loop); - void predicateOn(kir::Expr*); + void predicateOn(Expr*); void openLoop(kir::ForLoop*); @@ -160,8 +161,8 @@ class TORCH_CUDA_CU_API UnswitchPredicate { //! static, only pick the most restrictive one, e.g., the one with the //! minimum offset for the start predication. void mergeUnswitchPredicateOffsets( - const std::vector& predicates, - const std::vector& offsets, + Bool* predicate, + Val* offset, MergedPredicates::Info& merged_predicate_info, bool is_start); @@ -181,7 +182,7 @@ class TORCH_CUDA_CU_API UnswitchPredicate { parallelized_dom_predicates_; //! The predicates that have been generated. - std::vector predicates_; + std::vector predicates_; std::vector for_loops_; diff --git a/torch/csrc/jit/codegen/cuda/reference_tensor.h b/torch/csrc/jit/codegen/cuda/reference_tensor.h index 2220831dc09..07c83bb6ed7 100644 --- a/torch/csrc/jit/codegen/cuda/reference_tensor.h +++ b/torch/csrc/jit/codegen/cuda/reference_tensor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index ddb92371baa..b48c6b00b3a 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -196,7 +196,7 @@ UnmappableReductionDomains::UnmappableReductionDomains() { namespace { -//! Find all domains that a given domain is depeendent on +//! Find all domains that a given domain is dependent on class FindInputDomains : BackwardVisitor { private: FindInputDomains(TensorView* tv, const IterDomain* id) @@ -661,6 +661,58 @@ void ComputeAtRootDomainMapBuilder::setMapped( root_map_.eq_set_.join(producer, consumer); } +void ComputeAtRootDomainMapBuilder::setInvalid( + const DomainKey& key1, + const DomainKey& key2) { + invalid_mappings_.emplace_back(key1, key2); +} + +bool ComputeAtRootDomainMapBuilder::isInvalid( + const std::vector& domains) const { + // First, collect all invalid mappings for each of the keys in domains + DomainKeyMap invalid_key_map; + for (const auto& key : domains) { + DomainKeySet invalid_keys; + for (const auto& invalid_pair : invalid_mappings_) { + if (root_map_.canMap(key, invalid_pair.first)) { + invalid_keys.insert(invalid_pair.second); + } else if (root_map_.canMap(key, invalid_pair.second)) { + invalid_keys.insert(invalid_pair.first); + } + } + invalid_key_map.emplace(key, invalid_keys); + } + + // Next, check if any pair is invalid to map. + const auto num_keys = domains.size(); + for (const auto i : c10::irange(num_keys)) { + const auto& key_i = domains[i]; + // If no invalid keys found for key_i, it can be skipped. + const auto invalid_key_map_it = invalid_key_map.find(key_i); + if (invalid_key_map_it == invalid_key_map.end()) { + continue; + } + + // Set of keys that are invalid to be mapped with key_i. + const DomainKeySet& invalid_keys_for_i = invalid_key_map_it->second; + + // If any other key in domains is identified mappable with any of + // the keys in this set, the mapping with key_i is invalid. 
+ for (const auto j : c10::irange(i + 1, num_keys)) { + const auto& key_j = domains[j]; + if (std::any_of( + invalid_keys_for_i.begin(), + invalid_keys_for_i.end(), + [&](const auto& invalid_key_for_i) { + return root_map_.canMap(key_j, invalid_key_for_i); + })) { + return true; + } + } + } + return false; +} + void ComputeAtRootDomainMapBuilder::setMaybeMapped( const TensorDomain* producer_td, const IterDomain* producer_id, @@ -853,9 +905,11 @@ bool ComputeAtRootDomainMapBuilder::mapAllConsumers( // All entries in key_set must be equivalent with each other. TORCH_INTERNAL_ASSERT(consumer_set.size() > 0); bool consistent = safeToMap(consumer_set); - if (consistent) { - for (const auto pending_consumer : consumer_set) { + for (const auto pending_consumer : consumer_set) { + if (consistent) { setMapped(producer_key, pending_consumer); + } else { + setInvalid(producer_key, pending_consumer); } } // This entry should never be used again, so remove it. @@ -931,6 +985,10 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { !map_through_reduction_) { return false; } + // Make sure mapping these domains won't cause any invalid mapping + if (isInvalid(unique_domains)) { + return false; + } return true; } diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index 23ada0fb120..5156dc604f1 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -5,7 +5,7 @@ #include #include -#include +#include namespace torch { namespace jit { @@ -110,7 +110,7 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap { const TensorView* consumer_tv_ = nullptr; }; -std::string toString(const PairwiseRootDomainMap& root_map); +TORCH_CUDA_CU_API std::string toString(const PairwiseRootDomainMap& root_map); //! Represents an iteration domain of a TensorDomain. Only used for //! root domain mapping. @@ -206,7 +206,7 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor { //! This will create mappings between i0, i2 and i4. class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { friend class ComputeAtRootDomainMapBuilder; - friend std::string toString(const ComputeAtRootDomainMap&); + friend TORCH_CUDA_CU_API std::string toString(const ComputeAtRootDomainMap&); public: //! Builds a mapping table by analyzing the current @@ -327,7 +327,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { std::unordered_set window_axes_; }; -std::string toString(const ComputeAtRootDomainMap& root_map); +TORCH_CUDA_CU_API std::string toString(const ComputeAtRootDomainMap& root_map); //! Create a DisjointSet of root IterDomains by traversing the //! current fusion entirely. IterDomains that can be mapped each @@ -347,6 +347,12 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder //! Set a pair of producer-consumer domain keys as mappable void setMapped(const DomainKey& producer, const DomainKey& consumer); + //! Records two domains are invalid to map + void setInvalid(const DomainKey& key1, const DomainKey& key2); + + //! Check if no pair of domains is invalid to map + bool isInvalid(const std::vector& domains) const; + //! Track a pair of producer-consumer domains as potentially mappable. Inserts //! entries into pending_map_, but does not add anything into the root_map_ //! (added when handle is called on a TensorView). 
Maybe mapped will, however, @@ -415,10 +421,13 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder private: ComputeAtRootDomainMap& root_map_; - //! Keep track of what we want to try and map. Set in attemptToProveId. + //! Keep track of what we want to try and map DomainKeyMap pending_map_; std::unordered_set visited_; + //! Helper class to find invalid mappings due to reductions UnmappableReductionDomains incompatible_domains_; + //! Running vector of domain pairs that are invalid to map + std::vector> invalid_mappings_; //! Disable UnmappableReductions check, should //! always be false for compute_at use cases diff --git a/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu b/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu index ed366132689..fcbc98e7818 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu @@ -41,10 +41,8 @@ __device__ void sync() { // threads have incremented the counter. while (local_sync_counter < next && old < local_sync_counter) { #if __CUDA_ARCH__ >= 700 - __nanosleep(backoff); -#else - // __nanosleep is not available for sm < 70 - assert(false); + // __nanosleep only available on compute capability 7.0 or higher + __nanosleep(backoff); // avoids busy waiting #endif if (backoff < backoff_max) { backoff *= 2; diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu index a75d0d5904a..83382f4704c 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu @@ -69,7 +69,7 @@ template < typename Func> __device__ void gridReduceLastBlock( T& out, - const T* in, + const volatile T* in, const nvfuser_index_t grid_reduction_segment_size, // Number of reductions across // grid reduce dimensions diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu index 0ccb07142aa..a134bd81c2d 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu @@ -54,10 +54,8 @@ __device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { // Put a sleep here so we have some breaks in probing the global // semaphore, giving a better chance for other warps/blocks to catch up. 
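As an illustrative aside, both sync paths here rely on the same guarded exponential-backoff idiom: sleep with __nanosleep where the architecture provides it (sm_70 and newer) and otherwise just keep polling. A minimal device-side sketch of that idiom, with illustrative names and constants rather than anything taken from the patch:

    __device__ void spinWaitWithBackoff(volatile int* flag, int expected) {
      unsigned backoff = 8;             // initial sleep in nanoseconds (illustrative)
      const unsigned backoff_max = 256; // cap on the sleep duration (illustrative)
      while (*flag != expected) {
    #if __CUDA_ARCH__ >= 700
        // __nanosleep is only available on compute capability 7.0 or higher
        __nanosleep(backoff); // avoids hammering the flag in a tight loop
    #endif
        if (backoff < backoff_max) {
          backoff *= 2; // exponential backoff up to the cap
        }
      }
    }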
#if __CUDA_ARCH__ >= 700 - __nanosleep(200); -#else - // __nanosleep is not available for sm < 70 - assert(false); + // __nanosleep only available on compute capability 7.0 or higher + __nanosleep(200); // avoids busy waiting #endif } } diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu index 61dccb4dff2..02fd8bf8777 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu @@ -279,3 +279,19 @@ template <> double pow(double a, double b) { return ::pow(a, b); } + +float pow(float a, int b) { + return pow(a, (float)b); +} + +double pow(double a, int b) { + return pow(a, (double)b); +} + +float pow(float a, int64_t b) { + return pow(a, (float)b); +} + +double pow(double a, int64_t b) { + return pow(a, (double)b); +} diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu index aab51a8f158..ac4f2069b3b 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu @@ -19,3 +19,13 @@ struct Tensor { T* data; }; + +// Specialization for 0-dim case that's easy to pass in a CPU based tensor. +template +struct CpuScalarTensor { + __device__ T& operator[](int) { + return data; + }; + + T data; +}; diff --git a/torch/csrc/jit/codegen/cuda/runtime/welford.cu b/torch/csrc/jit/codegen/cuda/runtime/welford.cu index 07d848c55f2..c3b09d82b74 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/welford.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/welford.cu @@ -8,8 +8,8 @@ __inline__ __device__ void welfordCombine( T& a_avg, T& a_M2, TN& a_N, - const T& b_avg, - const T& b_M2, + const T b_avg, + const T b_M2, TN b_N) { if (b_N == 0) { return; @@ -183,9 +183,9 @@ __device__ void gridWelfordLastBlock( T& out_avg, T& out_M2, TN& out_N, - const T* in_avg, - const T* in_M2, - const TN* in_N, + const volatile T* in_avg, + const volatile T* in_M2, + const volatile TN* in_N, const nvfuser_index_t grid_reduction_segment_size, // Number of reductions across // grid reduce dimensions @@ -345,9 +345,9 @@ __device__ void gridWelford( out_avg, out_M2, out_N, - (T*)work_buf_avg, - (T*)work_buf_M2, - (TN*)work_buf_N, + work_buf_avg, + work_buf_M2, + work_buf_N, grid_reduction_segment_size, block_reduction_segment_size, shared_buf_avg, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index b856d83ac92..8aa3081fcc6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -43,6 +43,9 @@ ReductionParams innerPersistentHeuristic( // Set some targets for parallelization const int64_t n_elems = total_reduction_numel * total_iteration_numel; + const int64_t outer_reduction_numel = + total_reduction_numel / inner_most_dimension_numel; + // WARNING: At some point we may want to generate heuristics for another // device that is not the current device. 
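As an illustrative aside, the 0-dim CpuScalarTensor added in tensor.cu above lets a CPU-resident scalar be passed to a kernel by value yet indexed like a tensor. A hedged sketch of how such a parameter could be consumed on the device; the struct and kernel below are stand-ins (the real template parameter list is elided in the diff text):

    template <typename T>
    struct CpuScalarTensorSketch {
      __device__ T& operator[](int) {
        return data; // 0-dim: every index resolves to the single element
      }
      T data;
    };

    __global__ void addScalarSketch(float* out, CpuScalarTensorSketch<float> s, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] += s[0]; // read the CPU-provided scalar like a tensor element
      }
    }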
const int64_t device_max_threads_per_multiprocessor = @@ -228,7 +231,7 @@ ReductionParams innerPersistentHeuristic( bdimz = std::min( std::min( std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1), - ceilDiv(total_reduction_numel, inner_most_dimension_numel)), + outer_reduction_numel), scheduler_utils::z_block_limit); // If 3D doesn't fill out the threads, adjust to add to bdimy @@ -251,15 +254,13 @@ ReductionParams innerPersistentHeuristic( bdimz = std::min( std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1), - ceilDiv(total_reduction_numel, inner_most_dimension_numel)); + outer_reduction_numel); bdimy = std::min( std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1), max_multi_reduction_factor); } - godim = ceilDiv(total_iteration_numel, bdimy); - bool vectorize = false; // Move unrolling factor into vectorization upto vectorization limit. @@ -275,8 +276,7 @@ ReductionParams innerPersistentHeuristic( if (inner_reduction_unroll_factor < max_unroll) { outer_reduction_unroll_factor = std::min( ceilDiv(max_unroll, inner_reduction_unroll_factor), - ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), bdimz)); + ceilDiv(outer_reduction_numel, bdimz)); } godim = ceilDiv(total_iteration_numel, bdimy); @@ -304,9 +304,8 @@ ReductionParams innerPersistentHeuristic( while (outer_reduction_unroll_factor < max_unroll && batches_per_block_outer_reduction >= 2) { outer_reduction_unroll_factor *= 2; - batches_per_block_outer_reduction = roundUpPow2Or8(ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), - bdimz * outer_reduction_unroll_factor)); + batches_per_block_outer_reduction = roundUpPow2Or8( + ceilDiv(outer_reduction_numel, bdimz * outer_reduction_unroll_factor)); } // If we haven't gotten to the max_unroll case, try to take it out of the @@ -334,7 +333,7 @@ ReductionParams innerPersistentHeuristic( inner_most_dimension_numel, inner_reduction_unroll_factor * batches_per_block_inner_reduction); bdimz = ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), + outer_reduction_numel, outer_reduction_unroll_factor * batches_per_block_outer_reduction); // Try moving persistent buffer factors into threads until we have too many @@ -368,9 +367,8 @@ ReductionParams innerPersistentHeuristic( batches_per_block_outer_reduction = roundUpPow2Or8(batches_per_block_outer_reduction / 2); bdimz = ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), + outer_reduction_numel, batches_per_block_outer_reduction * outer_reduction_unroll_factor); - continue; } break; @@ -410,13 +408,18 @@ ReductionParams innerPersistentHeuristic( pad_bdimx = pad_bdimx && bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel; + // Will be used once supporting inter-block persistence + int64_t gdimx = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimy = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimz = LaunchParams::UNINITIALIZED_VAL; + ReductionParams rparams; rparams.persistent_kernel = true; rparams.fastest_dim = true; // Inner reduction domain - rparams.cross_block_inner_reduce = true; + rparams.cross_block_inner_reduction = true; rparams.block_dim_inner_reduction = ParallelType::TIDx; rparams.pad_inner_reduction_to_warp = pad_bdimx; rparams.batches_per_block_inner_reduction = batches_per_block_inner_reduction; @@ -432,8 +435,15 @@ ReductionParams innerPersistentHeuristic( if (rparams.multiple_reds_per_blk) { rparams.block_dim_iter_dom = ParallelType::TIDy; } - rparams.grid_dim_iter_dom = ParallelType::BIDx; - 
rparams.split_grid_dim_iter_dom = godim > scheduler_utils::x_grid_limit; + + if (godim > 1) { + rparams.grid_dim_iter_dom = ParallelType::BIDx; + if (godim > scheduler_utils::x_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimx = scheduler_utils::x_grid_limit; + } + } + if (iter_unroll_factor > 1) { rparams.unroll_iter_dom = true; rparams.unroll_factor_iter_dom = iter_unroll_factor; @@ -445,15 +455,15 @@ ReductionParams innerPersistentHeuristic( rparams.batches_per_block_outer_reduction = batches_per_block_outer_reduction; rparams.block_dim_outer_reduction = ParallelType::TIDz; - rparams.cross_block_outer_reduce = true; + rparams.cross_block_outer_reduction = true; rparams.unroll_outer_reduction = outer_reduction_unroll_factor > 1; rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor; } rparams.lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, - LaunchParams::UNINITIALIZED_VAL, - LaunchParams::UNINITIALIZED_VAL, + gdimx, + gdimy, + gdimz, LaunchParams::UNINITIALIZED_VAL, bdimy, LaunchParams::UNINITIALIZED_VAL); @@ -697,8 +707,8 @@ ReductionParams OuterPersistentHeuristic( rparams.persistent_kernel = true; rparams.fastest_dim = false; - rparams.cross_block_inner_reduce = true; - rparams.cross_grid_inner_reduce = false; + rparams.cross_block_inner_reduction = true; + rparams.cross_grid_inner_reduction = false; rparams.multiple_reds_per_blk = bdimx > 1; if (rparams.multiple_reds_per_blk) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index fb478f1110f..fb465b287e6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -391,6 +391,12 @@ class DomainMap { return nullptr; } + static bool hasReferenceTensorView(Fusion* fusion) { + FusionGuard fg(fusion); + DomainMap domain_map(fusion); + return domain_map.findReferenceTensorView() != nullptr; + } + private: // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input. @@ -417,7 +423,8 @@ class DomainMap { // Get concrete IDs for input root or rfactor domain std::unordered_set in_concrete_ids; for (auto in_id : input_tv->getMaybeRFactorDomain()) { - if (!in_id->isBroadcast() && !in_id->isReduction()) { + if (!ca_index_map_.getConcreteMappedID(in_id)->isBroadcast() && + !in_id->isReduction()) { in_concrete_ids.insert(ca_index_map_.getConcreteMappedID(in_id)); } } @@ -491,6 +498,10 @@ class DomainMap { } // namespace +bool hasReferenceTensorView(Fusion* fusion) { + return DomainMap::hasReferenceTensorView(fusion); +} + // TODO: Inline intermediate operations (avoid inlining unrolled/vectorized // input/output caches) void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { @@ -503,7 +514,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // maybe has_reduction for scheduling should be done on a per output tensor // basis. 
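As an illustrative aside, the assertion changed just below replaces a tensor-level hasReduction() query with a scan over the fusion's remaining expressions. A minimal sketch of that style of guard, using stand-in expression classes rather than the real IR types:

    #include <vector>

    struct ExprBase {
      virtual ~ExprBase() = default;
    };
    struct ReductionLikeOp : ExprBase {}; // stand-in for reduction/Welford ops

    // Returns true only when no reduction-like expression survives in the list,
    // i.e. the fusion is safe for a pointwise-only scheduler.
    bool onlyPointwiseExprs(const std::vector<ExprBase*>& exprs) {
      for (auto* e : exprs) {
        if (dynamic_cast<ReductionLikeOp*>(e) != nullptr) {
          return false;
        }
      }
      return true;
    }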
TORCH_INTERNAL_ASSERT( - !fusion->hasReduction(), "This scheduler only handles pointwise ops."); + ir_utils::getReductionOps(fusion).empty(), + "This scheduler only handles pointwise ops."); // For intermediate outputs, apply cache_fork auto outs = fusion->outputs(); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h index cb626556579..57b77bb20cc 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h @@ -31,6 +31,11 @@ TORCH_CUDA_CU_API LaunchParams schedulePointwise( Fusion* fusion, const at::ArrayRef& runtime_inputs); +//! Utility for canSchedule interface to check if this fusion has +//! a fully broadcasted reference tensor, which is necessary for +//! the pointwise scheduler. +bool hasReferenceTensorView(Fusion* fusion); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index b0d4f12b921..088968b0890 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -334,9 +334,9 @@ ReductionParams innerReductionHeuristic( ReductionParams rparams; rparams.fastest_dim = true; - rparams.cross_block_inner_reduce = true; + rparams.cross_block_inner_reduction = true; rparams.block_dim_inner_reduction = ParallelType::TIDx; - rparams.cross_grid_inner_reduce = gridim > 1; + rparams.cross_grid_inner_reduction = gridim > 1; rparams.multiple_reds_per_blk = bdimy > 1; bool pad_bdimx = bdimx > 16 && bdimx * bdimy < @@ -359,7 +359,9 @@ ReductionParams innerReductionHeuristic( rparams.vectorize_inner_reduction = vectorize; } - rparams.block_dim_iter_dom = ParallelType::TIDy; + if (rparams.multiple_reds_per_blk) { + rparams.block_dim_iter_dom = ParallelType::TIDy; + } if (iter_unroll_factor > 1) { rparams.unroll_iter_dom = true; rparams.unroll_factor_iter_dom = iter_unroll_factor; @@ -368,10 +370,10 @@ ReductionParams innerReductionHeuristic( rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel; // Outer reduction domain if (rparams.schedule_3D) { - rparams.cross_grid_outer_reduce = grodim > 1; + rparams.cross_grid_outer_reduction = grodim > 1; if (bdimz > 1) { rparams.block_dim_outer_reduction = ParallelType::TIDz; - rparams.cross_block_outer_reduce = true; + rparams.cross_block_outer_reduction = true; } rparams.unroll_outer_reduction = outer_reduction_unroll_factor > 1; rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor; @@ -385,39 +387,40 @@ ReductionParams innerReductionHeuristic( // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in // case it's larger than gdimy can hold, as not doing so can thrash the cache. 
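As an illustrative aside, the launch-parameter hunks that follow (and the persistent heuristic above) converge on one clamp-and-split pattern: bind a grid dimension only when it is actually needed, and when the requested extent exceeds the hardware grid limit, mark the axis for splitting and launch with the clamped size. A self-contained sketch under those assumptions, with made-up type names:

    #include <algorithm>
    #include <cstdint>

    struct GridDimRequest {
      bool bound = false;        // whether this grid dimension is used at all
      bool needs_split = false;  // extent exceeded the limit, so split the axis
      int64_t launch_size = -1;  // -1 stands in for "uninitialized" here
    };

    GridDimRequest bindGridDim(int64_t extent, int64_t grid_limit) {
      GridDimRequest req;
      if (extent > 1) {
        req.bound = true;
        if (extent > grid_limit) {
          req.needs_split = true;
          req.launch_size = std::min(extent, grid_limit); // clamp to the limit
        }
      }
      return req;
    }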
- if (rparams.cross_grid_inner_reduce) { + if (rparams.cross_grid_inner_reduction) { rparams.grid_dim_inner_reduction = ParallelType::BIDx; - gdimx = gridim; - rparams.split_grid_dim_inner_reduction = - gdimx > scheduler_utils::x_grid_limit; + rparams.split_grid_dim_inner_reduction = true; + gdimx = std::min(gridim, scheduler_utils::x_grid_limit); rparams.grid_dim_iter_dom = ParallelType::BIDy; - gdimy = godim; - rparams.split_grid_dim_iter_dom = gdimy > scheduler_utils::y_grid_limit; + if (godim > scheduler_utils::y_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimy = std::min(godim, scheduler_utils::y_grid_limit); + } } else { - gdimx = godim; rparams.grid_dim_iter_dom = ParallelType::BIDx; - rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; + if (gdimx > scheduler_utils::x_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimx = godim; + } } - if (rparams.cross_grid_outer_reduce) { - if (rparams.cross_block_inner_reduce) { - gdimz = grodim; + if (rparams.cross_grid_outer_reduction) { + if (rparams.cross_block_inner_reduction) { rparams.grid_dim_outer_reduction = ParallelType::BIDz; + gdimz = std::min(grodim, scheduler_utils::z_grid_limit); + rparams.split_grid_dim_outer_reduction = true; } else { - gdimy = grodim; rparams.grid_dim_outer_reduction = ParallelType::BIDy; + gdimy = std::min(grodim, scheduler_utils::y_grid_limit); + rparams.split_grid_dim_outer_reduction = true; } } rparams.lparams = LaunchParams( - rparams.grid_dim_iter_dom == ParallelType::BIDx - ? LaunchParams::UNINITIALIZED_VAL - : gdimx, - rparams.grid_dim_iter_dom == ParallelType::BIDy - ? LaunchParams::UNINITIALIZED_VAL - : gdimy, + gdimx, + gdimy, gdimz, bdimx, bdimy > 1 ? bdimy : LaunchParams::UNINITIALIZED_VAL, @@ -441,12 +444,13 @@ ReductionParams innerReductionHeuristic( // schedule if (rparams.schedule_3D) { if (rparams.multiple_reds_per_blk && - (rparams.cross_grid_inner_reduce || rparams.cross_grid_outer_reduce)) { + (rparams.cross_grid_inner_reduction || + rparams.cross_grid_outer_reduction)) { if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { std::cerr << "\n===== UNSUPPORTED REDUCTION HEURISTIC ========\n"; std::cerr << rparams.multiple_reds_per_blk << ", " << rparams.unroll_inner_reduction << ", " - << rparams.cross_grid_inner_reduce << std::endl; + << rparams.cross_grid_inner_reduction << std::endl; } return innerReductionHeuristic( total_reduction_numel, @@ -534,9 +538,9 @@ ReductionParams OuterReductionHeuristic( // domain for this // Blocks for reductions - int64_t gdimy = 1; + int64_t grdim = 1; // Blocks for outputs - int64_t gdimx = 1; + int64_t gidim = 1; // Threads for reduction int64_t bdimy = 1; @@ -597,11 +601,11 @@ ReductionParams OuterReductionHeuristic( std::min(max_unroll, ceilDiv(total_reduction_numel, bdimy)); // Go cross grid - gdimy = ceilDiv( + grdim = ceilDiv( ceilDiv(total_reduction_numel, bdimy * inner_reduction_unroll_factor), (int64_t)4); - gdimx = ceilDiv(total_iteration_numel, bdimx * iter_unroll_factor); + gidim = ceilDiv(total_iteration_numel, bdimx * iter_unroll_factor); // Clang tidy constexpr int64_t kEight = 8; @@ -611,13 +615,13 @@ ReductionParams OuterReductionHeuristic( if (ceilDiv(total_reduction_numel, bdimy * inner_reduction_unroll_factor) >= kThirtyTwo) { // Many reduction elements, go cross grid - int64_t min_gdimy = 1; - if (gdimy > 1) { + int64_t min_grdim = 1; + if (grdim > 1) { // already cross grid, don't go below target or what was already set - min_gdimy = std::min(gdimy, ceilDiv(target_blocks, gdimx)); + 
min_grdim = std::min(grdim, ceilDiv(target_blocks, gidim)); } - gdimy = std::max( - min_gdimy, + grdim = std::max( + min_grdim, ceilDiv( ceilDiv( total_reduction_numel, bdimy * inner_reduction_unroll_factor), @@ -625,33 +629,33 @@ ReductionParams OuterReductionHeuristic( // Don't go too far above number of threads in a block since that's how many // threads are available to do final reduction iteration // This is good! - gdimy = std::min(gdimy, bdimx * bdimy * kEight); + grdim = std::min(grdim, bdimx * bdimy * kEight); } // Try to do some cleanup of ragged waves on device if ( // If we have less than 8 waves of blocks - gdimy * gdimx < device_multiprocessor_count * kEight && + grdim * gidim < device_multiprocessor_count * kEight && // And we don't have an even divisible number of blocks - (gdimy * gdimx) % device_multiprocessor_count != 0 && + (grdim * gidim) % device_multiprocessor_count != 0 && // And we have more than one wave - gdimy * gdimx > device_multiprocessor_count) { + grdim * gidim > device_multiprocessor_count) { // round waves down auto waves = - std::max((gdimx * gdimy) / device_multiprocessor_count, (int64_t)1); - auto new_gdimy = - std::max((waves * device_multiprocessor_count) / gdimx, (int64_t)1); + std::max((gidim * grdim) / device_multiprocessor_count, (int64_t)1); + auto new_grdim = + std::max((waves * device_multiprocessor_count) / gidim, (int64_t)1); if ( - // If difference is less than 25% of the original gdimy - (new_gdimy - gdimy) * 4 < gdimy && + // If difference is less than 25% of the original grdim + (new_grdim - grdim) * 4 < grdim && // and difference is less than 25% of the original number of blocks - ((new_gdimy * gdimx) - (gdimy * gdimx)) * 4 < gdimy * gdimx) { - gdimy = new_gdimy; + ((new_grdim * gidim) - (grdim * gidim)) * 4 < grdim * gidim) { + grdim = new_grdim; } } // Cannot unroll with cross grid reductions - if (gdimy > 1 && iter_unroll_factor > 1) { + if (grdim > 1 && iter_unroll_factor > 1) { // Readjust the thread bindings, ideally we would repeat the block setup // without considering iter domain unrolling, but for now will simplify bdimx = std::min(max_threads_in_block, bdimx * iter_unroll_factor); @@ -664,10 +668,18 @@ ReductionParams OuterReductionHeuristic( iter_unroll_factor = 1; } + int64_t gdimx = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimy = LaunchParams::UNINITIALIZED_VAL; + ReductionParams rparams; // cross grid implies cross block - rparams.cross_block_inner_reduce = bdimy > 1 || gdimy > 1; - rparams.cross_grid_inner_reduce = gdimy > 1; + rparams.cross_block_inner_reduction = bdimy > 1 || grdim > 1; + rparams.cross_grid_inner_reduction = grdim > 1; + if (rparams.cross_grid_inner_reduction) { + rparams.split_grid_dim_inner_reduction = true; + rparams.grid_dim_inner_reduction = ParallelType::BIDy; + gdimy = std::min(grdim, scheduler_utils::y_grid_limit); + } rparams.multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1; if (rparams.multiple_reds_per_blk) { @@ -675,15 +687,12 @@ ReductionParams OuterReductionHeuristic( } rparams.grid_dim_iter_dom = ParallelType::BIDx; - rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; - - if (rparams.cross_grid_inner_reduce) { - rparams.grid_dim_inner_reduction = ParallelType::BIDy; - rparams.split_grid_dim_inner_reduction = - gdimy > scheduler_utils::y_grid_limit; + if (gidim > scheduler_utils::x_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimx = scheduler_utils::x_grid_limit; } - if (rparams.cross_block_inner_reduce) { + if 
(rparams.cross_block_inner_reduction) { if (rparams.block_dim_iter_dom == ParallelType::TIDx) { rparams.block_dim_inner_reduction = ParallelType::TIDy; } else { @@ -702,7 +711,7 @@ ReductionParams OuterReductionHeuristic( } rparams.lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, + gdimx, gdimy, LaunchParams::UNINITIALIZED_VAL, rparams.multiple_reds_per_blk ? bdimx : bdimy, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h index aafae3f09ff..a710e0c0ed8 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h @@ -31,9 +31,9 @@ class ReductionParams { // Inner Reduction Domain: // Reduce across the block? - bool cross_block_inner_reduce = false; + bool cross_block_inner_reduction = false; // Reduce across the grid? - bool cross_grid_inner_reduce = false; + bool cross_grid_inner_reduction = false; // Inner reduction unroll/vectorize bool unroll_inner_reduction = false; // Unrolling factor @@ -81,9 +81,9 @@ class ReductionParams { // Outer Reduction Domain if 3D Scheduled: // Reduce across the block? - bool cross_block_outer_reduce = false; + bool cross_block_outer_reduction = false; // Reduce across the grid? - bool cross_grid_outer_reduce = false; + bool cross_grid_outer_reduction = false; // Split grid dim for iteration axis in case it's too large for cuda bool split_grid_dim_outer_reduction = false; // Register persistent buffer size in outer dimension @@ -113,8 +113,8 @@ class ReductionParams { other.persistent_kernel == persistent_kernel && other.project_persistent_buffers == project_persistent_buffers && other.schedule_3D == schedule_3D && - other.cross_block_inner_reduce == cross_block_inner_reduce && - other.cross_grid_inner_reduce == cross_grid_inner_reduce && + other.cross_block_inner_reduction == cross_block_inner_reduction && + other.cross_grid_inner_reduction == cross_grid_inner_reduction && other.unroll_inner_reduction == unroll_inner_reduction && other.unroll_factor_inner_reduction == unroll_factor_inner_reduction && other.vectorize_inner_reduction == vectorize_inner_reduction && @@ -128,8 +128,8 @@ class ReductionParams { other.unroll_factor_iter_dom == unroll_factor_iter_dom && other.vectorize_iter_dom == vectorize_iter_dom && other.split_grid_dim_iter_dom == split_grid_dim_iter_dom && - other.cross_block_outer_reduce == cross_block_outer_reduce && - other.cross_grid_outer_reduce == cross_grid_outer_reduce && + other.cross_block_outer_reduction == cross_block_outer_reduction && + other.cross_grid_outer_reduction == cross_grid_outer_reduction && other.unroll_outer_reduction == unroll_outer_reduction && other.unroll_factor_outer_reduction == unroll_factor_outer_reduction && other.split_grid_dim_outer_reduction == @@ -153,10 +153,10 @@ class ReductionParams { if (schedule_3D) { ss << "3D Schedule\n" << "Outer Reduction: "; - if (cross_block_outer_reduce) { + if (cross_block_outer_reduction) { ss << "cross block - " << block_dim_outer_reduction << " / "; } - if (cross_grid_outer_reduce) { + if (cross_grid_outer_reduction) { ss << "cross grid - " << grid_dim_outer_reduction << " / "; ss << (split_grid_dim_outer_reduction ? "split grid dim / " : ""); } @@ -189,18 +189,18 @@ class ReductionParams { ss << "\nInner Reduction Domain: "; - if (cross_block_inner_reduce) { + if (cross_block_inner_reduction) { ss << "cross block - " << block_dim_inner_reduction << " / "; ss << (pad_inner_reduction_to_warp ? 
" pad to warp / " : ""); } - if (cross_grid_inner_reduce) { + if (cross_grid_inner_reduction) { ss << "cross grid - " << grid_dim_inner_reduction << " / "; ss << (split_grid_dim_inner_reduction ? "split grid dim / " : ""); } if (batches_per_block_inner_reduction > 1 || persistent_kernel) { ss << "persistent batch - " << batches_per_block_inner_reduction << " / "; } - ss << (cross_grid_inner_reduce && split_grid_dim_inner_reduction + ss << (cross_grid_inner_reduction && split_grid_dim_inner_reduction ? "split grid dimension / " : "") << (vectorize_inner_reduction ? "vectorize / " : "") @@ -225,8 +225,8 @@ class ReductionParamsHash { static_cast(rp.persistent_kernel) << (bits - 2) ^ static_cast(rp.project_persistent_buffers) << (bits - 3) ^ static_cast(rp.schedule_3D) << (bits - 4) ^ - static_cast(rp.cross_block_inner_reduce) << (bits - 5) ^ - static_cast(rp.cross_grid_inner_reduce) << (bits - 6) ^ + static_cast(rp.cross_block_inner_reduction) << (bits - 5) ^ + static_cast(rp.cross_grid_inner_reduction) << (bits - 6) ^ static_cast(rp.unroll_inner_reduction) << (bits - 7) ^ static_cast(rp.unroll_factor_inner_reduction) ^ static_cast(rp.vectorize_inner_reduction) << (bits - 8) ^ @@ -239,8 +239,8 @@ class ReductionParamsHash { static_cast(rp.unroll_factor_iter_dom) ^ static_cast(rp.vectorize_iter_dom) << (bits - 14) ^ static_cast(rp.split_grid_dim_iter_dom) << (bits - 15) ^ - static_cast(rp.cross_block_outer_reduce) << (bits - 16) ^ - static_cast(rp.cross_grid_outer_reduce) << (bits - 17) ^ + static_cast(rp.cross_block_outer_reduction) << (bits - 16) ^ + static_cast(rp.cross_grid_outer_reduction) << (bits - 17) ^ static_cast(rp.split_grid_dim_outer_reduction) << (bits - 18) ^ static_cast(rp.batches_per_block_outer_reduction) << (bits - 19); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp index 3850fa9638b..57988d8d994 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp @@ -43,257 +43,170 @@ TensorView* scheduleReductionTV( !(!rparams.fastest_dim && rparams.vectorize_inner_reduction), "Cannot vectorize reduction domain on outer reductions."); - TORCH_INTERNAL_ASSERT( - !(rparams.cross_grid_inner_reduce && rparams.persistent_kernel), - "Grid reductions not implemented yet for persistent kernels."); - TORCH_INTERNAL_ASSERT( !(rparams.multiple_reds_per_blk && !has_iter_axis), "Multiple reductions requires an iter domain, but one wasn't found."); TORCH_INTERNAL_ASSERT( - !(rparams.cross_grid_inner_reduce && rparams.unroll_iter_dom), + !(rparams.cross_grid_inner_reduction && rparams.unroll_iter_dom), "Unrolling on iter domain not supported with cross grid reductions."); TORCH_INTERNAL_ASSERT( !(rparams.unroll_iter_dom && !has_iter_axis), "Unrolling on iter domain requires an iter domain."); - // Inner reduction axis: - if (rparams.unroll_inner_reduction) { - if (rparams.persistent_kernel) { - if (rparams.vectorize_inner_reduction) { - reduction_tv->split( - inner_reduce_axis, - rparams.batches_per_block_inner_reduction, - false); - reduction_tv->split( - inner_reduce_axis + 1, rparams.unroll_factor_inner_reduction); - - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(ParallelType::Vectorize); - } else { - 
reduction_tv->split( - inner_reduce_axis, - rparams.batches_per_block_inner_reduction * - rparams.unroll_factor_inner_reduction, - false); - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 2)->padToMultipleOfWarp(); - } - } - } else { - if (isParallelTypeThread(rparams.block_dim_inner_reduction)) { - if (rparams.vectorize_inner_reduction) { - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_inner_reduction)); - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(ParallelType::Vectorize); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } - } else { - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_inner_reduction)); - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(rparams.block_dim_inner_reduction); - - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 2)->padToMultipleOfWarp(); - } - } - } else { - // Inner reduction is not parallelized, but is unrolled or vectorized: - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize( - rparams.vectorize_inner_reduction ? 
ParallelType::Vectorize - : ParallelType::Unroll); - } + auto vectorize = [&reduction_tv](int axis, int factor) { + reduction_tv->split(axis, factor); + reduction_tv->axis(axis + 1)->parallelize(ParallelType::Vectorize); + }; + + auto inner_parallel = [&reduction_tv](int axis, ParallelType ptype) { + reduction_tv->split(axis, NamedScalar::getParallelDim(ptype)); + reduction_tv->axis(axis + 1)->parallelize(ptype); + }; + + auto inner_unswitch = [&reduction_tv](int axis) { + reduction_tv->split(axis, 1); + reduction_tv->axis(axis + 1)->parallelize(ParallelType::Unswitch); + }; + + auto inner_unroll = [&reduction_tv](int axis, int factor) { + reduction_tv->split(axis, factor); + reduction_tv->axis(axis + 1)->parallelize(ParallelType::Unroll); + }; + + auto outer_parallel = [&reduction_tv](int axis, ParallelType ptype) { + reduction_tv->split(axis, NamedScalar::getParallelDim(ptype), false); + reduction_tv->axis(axis)->parallelize(ptype); + }; + + auto outer_unswitch = [&reduction_tv](int axis) { + reduction_tv->split(axis, 1, false); + reduction_tv->axis(axis)->parallelize(ParallelType::Unswitch); + }; + + auto outer_unroll = [&reduction_tv](int axis, int factor) { + reduction_tv->split(axis, factor, false); + reduction_tv->axis(axis)->parallelize(ParallelType::Unroll); + }; + + if (rparams.persistent_kernel) { + // Persistent Format: + // [Grid Split, persistent buffer, unswitch, unroll, thread dim, vectorize] + if (rparams.vectorize_inner_reduction) { + vectorize(inner_reduce_axis, rparams.unroll_factor_inner_reduction); + } + auto outer_i = inner_reduce_axis; + if (rparams.cross_grid_inner_reduction) { + outer_parallel(outer_i++, rparams.grid_dim_inner_reduction); + } + + reduction_tv->split( + outer_i++, rparams.batches_per_block_inner_reduction, false); + + outer_unswitch(outer_i++); + + if (!rparams.vectorize_inner_reduction && rparams.unroll_inner_reduction) { + outer_unroll(outer_i++, rparams.unroll_factor_inner_reduction); + } + + reduction_tv->axis(outer_i)->parallelize(rparams.block_dim_inner_reduction); + + if (rparams.pad_inner_reduction_to_warp) { + reduction_tv->axis(outer_i)->padToMultipleOfWarp(); } - // Unswitch axis which gives us finer control on allocations with - // unrolling - reduction_tv->split(inner_reduce_axis, 1); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unswitch); } else { - // Parallelize reduction axis, don't unroll it0 - if (rparams.cross_block_inner_reduce) { - if (rparams.persistent_kernel) { - reduction_tv->split( - inner_reduce_axis, - rparams.batches_per_block_inner_reduction, - false); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } - } else { - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_inner_reduction)); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } + // Non-persistent format: + // [Grid Split, Remainder, unswitch, unroll, thread dim, vectorize] + if (rparams.vectorize_inner_reduction) { + vectorize(inner_reduce_axis, rparams.unroll_factor_inner_reduction); + } + + if (rparams.cross_block_inner_reduction) { + inner_parallel(inner_reduce_axis, rparams.block_dim_inner_reduction); + if (rparams.pad_inner_reduction_to_warp) { + 
reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); } - } else { - // No parallelization on reduction dim, fake an unswitch axis for - // rfactor - reduction_tv->split(inner_reduce_axis, 1); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unswitch); } - } - if (rparams.cross_grid_inner_reduce) { - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.grid_dim_inner_reduction), - false); - reduction_tv->axis(inner_reduce_axis) - ->parallelize(rparams.grid_dim_inner_reduction); + if (!rparams.vectorize_inner_reduction && rparams.unroll_inner_reduction) { + inner_unroll(inner_reduce_axis, rparams.unroll_factor_inner_reduction); + } + + inner_unswitch(inner_reduce_axis); + if (rparams.cross_grid_inner_reduction) { + if (rparams.split_grid_dim_inner_reduction) { + outer_parallel(inner_reduce_axis, rparams.grid_dim_inner_reduction); + } else { + reduction_tv->axis(inner_reduce_axis) + ->parallelize(rparams.grid_dim_inner_reduction); + } + } } // Outer reduction axis if (rparams.schedule_3D) { - if (rparams.unroll_outer_reduction) { - if (rparams.persistent_kernel) { - reduction_tv->split( - outer_reduce_axis, - rparams.batches_per_block_outer_reduction * - rparams.unroll_factor_outer_reduction, - false); - reduction_tv->split( - outer_reduce_axis, rparams.unroll_factor_outer_reduction); - - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(outer_reduce_axis + 2) - ->parallelize(rparams.block_dim_outer_reduction); - } else { - if (isParallelTypeThread(rparams.block_dim_outer_reduction)) { - reduction_tv->split( - outer_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_outer_reduction)); - reduction_tv->split( - outer_reduce_axis, rparams.unroll_factor_outer_reduction); - - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(outer_reduce_axis + 2) - ->parallelize(rparams.block_dim_outer_reduction); + if (rparams.persistent_kernel) { + // Persistent Format: + // [Grid Split, persistent buffer, unroll, thread dim] + auto outer_i = outer_reduce_axis; + if (rparams.cross_grid_outer_reduction) { + outer_parallel(outer_i++, rparams.grid_dim_outer_reduction); + } - } else { - // outer reduction is not parallelized, but is unrolled or vectorized: - reduction_tv->split( - outer_reduce_axis, rparams.unroll_factor_outer_reduction); - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - } + reduction_tv->split( + outer_i++, rparams.batches_per_block_outer_reduction, false); + + if (rparams.unroll_outer_reduction) { + outer_unroll(outer_i++, rparams.unroll_factor_outer_reduction); } + + reduction_tv->axis(outer_i)->parallelize( + rparams.block_dim_outer_reduction); } else { - // Parallelize reduction axis, don't unroll it0 - if (rparams.cross_block_outer_reduce) { - if (rparams.persistent_kernel) { - reduction_tv->split( - outer_reduce_axis, - rparams.batches_per_block_outer_reduction, - false); - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(rparams.block_dim_outer_reduction); - } else { - reduction_tv->split( - outer_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_outer_reduction)); - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(rparams.block_dim_outer_reduction); - } + // Non-persistent format: + // [Grid Split, Remainder, unroll, thread dim] + if (rparams.cross_block_outer_reduction) { + inner_parallel(outer_reduce_axis, rparams.block_dim_outer_reduction); } - } - 
if (rparams.cross_grid_outer_reduce) { - reduction_tv->split( - outer_reduce_axis, - NamedScalar::getParallelDim(rparams.grid_dim_outer_reduction), - false); - reduction_tv->axis(outer_reduce_axis) - ->parallelize(rparams.grid_dim_outer_reduction); + if (rparams.unroll_outer_reduction) { + inner_unroll(outer_reduce_axis, rparams.unroll_factor_outer_reduction); + } + + if (rparams.cross_grid_outer_reduction) { + outer_parallel(outer_reduce_axis, rparams.grid_dim_outer_reduction); + } } } // Iteration domain if (has_iter_axis) { + // [Grid Split, unswitch, unroll, thread dim, vectorize] + + if (rparams.vectorize_iter_dom) { + vectorize(iter_axis, rparams.unroll_factor_iter_dom); + } + if (isParallelTypeThread(rparams.block_dim_iter_dom)) { - if (rparams.vectorize_iter_dom) { - reduction_tv->split(iter_axis, rparams.unroll_factor_iter_dom); - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Vectorize); - - reduction_tv->split( - iter_axis, NamedScalar::getParallelDim(rparams.block_dim_iter_dom)); - reduction_tv->axis(iter_axis + 1) - ->parallelize(rparams.block_dim_iter_dom); - } else { - if ((rparams.fastest_dim && rparams.multiple_reds_per_blk) || - !rparams.fastest_dim) { - reduction_tv->split( - iter_axis, - NamedScalar::getParallelDim(rparams.block_dim_iter_dom)); - reduction_tv->axis(iter_axis + 1) - ->parallelize(rparams.block_dim_iter_dom); - } - if (rparams.unroll_iter_dom) { - reduction_tv->split(iter_axis, rparams.unroll_factor_iter_dom); - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Unroll); - } - } - } else if (rparams.unroll_iter_dom) { - // Iteration domain is not parallelized but it is unrolled or vectorized - reduction_tv->split(iter_axis, rparams.unroll_factor_iter_dom); - if (rparams.vectorize_iter_dom) { - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Vectorize); - } else { - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Unroll); - } + inner_parallel(iter_axis, rparams.block_dim_iter_dom); } + + if (!rparams.vectorize_iter_dom && rparams.unroll_iter_dom) { + inner_unroll(iter_axis, rparams.unroll_factor_iter_dom); + } + if (rparams.unroll_iter_dom) { - reduction_tv->split(iter_axis, 1); - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Unswitch); + inner_unswitch(iter_axis); } - if (rparams.fastest_dim && rparams.split_grid_dim_iter_dom) { - reduction_tv->split(iter_axis, scheduler_utils::x_grid_limit); - reduction_tv->axis(iter_axis + 1)->parallelize(rparams.grid_dim_iter_dom); - } else { - reduction_tv->axis(iter_axis)->parallelize(rparams.grid_dim_iter_dom); + if (isParallelTypeThread(rparams.grid_dim_iter_dom)) { + if (rparams.split_grid_dim_iter_dom) { + outer_parallel(iter_axis, rparams.grid_dim_iter_dom); + } else { + reduction_tv->axis(iter_axis)->parallelize(rparams.grid_dim_iter_dom); + } } } @@ -563,6 +476,48 @@ void multiReductionInliner( scheduler_utils::computeWithOutputs( red_tv, pos, ComputeAtMode::BestEffort); } + // For topologies where there may not be paths to all inputs/outputs from + // the reductions, we need to take a similar approach to the unrolled + // version and setup of compute at from inputs->outputs that are not + // inputs/outputs of the reductions. 
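As an illustrative aside, the split helpers above encode a fixed axis ordering, summarized by the persistent-format comment [Grid Split, persistent buffer, unswitch, unroll, thread dim, vectorize]. The toy, IR-free model below shows how inner splits (new factor at axis + 1) and outer splits (new factor at axis) compose into that order; not every slot appears at once in the real heuristics:

    #include <iostream>
    #include <string>
    #include <vector>

    struct ToyTensor {
      // Each axis is labeled by how it ends up parallelized (or used).
      std::vector<std::string> axes{"reduction remainder"};

      // Inner split: the new factor becomes axis + 1 (like vectorize/inner_*).
      void splitInner(int axis, const std::string& tag) {
        axes.insert(axes.begin() + axis + 1, tag);
      }
      // Outer split: the new factor stays at `axis` (like the outer_* helpers).
      void splitOuter(int axis, const std::string& tag) {
        axes.insert(axes.begin() + axis, tag);
      }
    };

    int main() {
      ToyTensor tv;
      int axis = 0;
      tv.splitInner(axis, "Vectorize");           // innermost vectorized factor
      tv.splitOuter(axis++, "BIDx grid split");   // outer grid parallelization
      tv.splitOuter(axis++, "persistent buffer"); // batches-per-block split
      tv.splitOuter(axis++, "Unswitch");          // unswitch factor of 1
      tv.splitOuter(axis++, "Unroll");            // unroll factor
      tv.axes[axis] = "TIDx thread dim";          // remainder bound to the thread dim
      for (const auto& a : tv.axes) {
        std::cout << a << " | ";
      }
      std::cout << "\n"; // BIDx | persistent | Unswitch | Unroll | TIDx | Vectorize
    }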
+ std::vector compute_to; + std::unordered_set outs_of_reds; + { + auto outs_of_red_vec = ir_utils::outputTvsOf(ref_tvs); + outs_of_reds = std::unordered_set( + outs_of_red_vec.begin(), outs_of_red_vec.end()); + } + for (auto out : ir_utils::filterByType(fusion->outputs())) { + // only terminating outputs + if (out->uses().size()) { + continue; + } + if (outs_of_reds.find(out) != outs_of_reds.end()) { + continue; + } + compute_to.push_back(out); + } + + std::vector compute_from; + std::unordered_set inps_of_reds; + { + auto inps_of_red_vec = ir_utils::inputTvsOf(ref_tvs); + inps_of_reds = std::unordered_set( + inps_of_red_vec.begin(), inps_of_red_vec.end()); + } + for (auto inp : ir_utils::filterByType(fusion->inputs())) { + if (inps_of_reds.find(inp) != inps_of_reds.end()) { + continue; + } + compute_from.push_back(inp); + } + + scheduler_utils::computeAtBetween( + compute_from, + compute_to, + -1, + ComputeAtMode::MostInlined, + mapped_to_trivial_reduction); } } @@ -595,12 +550,6 @@ int idPos(const IterDomain* id) { } inner_most--; - // Reduction and block - if (id->isReduction() && id->isBlockDim()) { - return inner_most; - } - inner_most--; - // Reduction and constant if (id->isReduction() && id->extent()->isConstScalar()) { return inner_most; @@ -614,7 +563,7 @@ int idPos(const IterDomain* id) { inner_most--; // Reduction and thread - if (id->isReduction() && id->isThreadDim()) { + if (id->isReduction() && id->isThread()) { return inner_most; } inner_most--; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 46b574ac6af..4f2982b01f2 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -1,10 +1,12 @@ #include +#include #include #include #include #include #include #include +#include #include #include @@ -38,7 +40,8 @@ class SchedulerTopologyChecker { auto all_vals = fusion->usedMathVals(); std::vector reduction_tvs; for (auto tv : ir_utils::filterByType(all_vals)) { - if (tv->hasReduction() && !fusion->hasInput(tv)) { + if (tv->hasReduction() && + !(fusion == tv->fusion() && tv->isFusionInput())) { reduction_tvs.push_back(tv); } } @@ -355,6 +358,50 @@ class SchedulerTopologyChecker { return true; } }; + +bool isConnectedFusionGraph(Fusion* fusion) { + if (fusion->outputs().empty()) { + // Trivial case interpreted as connected + return true; + } + + // A set of connected components on the fusion graph + DisjointSet component_sets; + + // Iterate through all used exprs + for (auto expr : fusion->exprs()) { + TORCH_INTERNAL_ASSERT( + !expr->inputs().empty(), "unknown expr with zero input"); + + // Each expr joins all its inputs and + // outputs to the same component + auto input0 = expr->inputs()[0]; + for (auto input : expr->inputs()) { + component_sets.join(input0, input); + } + for (auto output : expr->outputs()) { + component_sets.join(input0, output); + } + } + + // Join aliased outputs + for (auto alias_it : fusion->ioAlias()) { + component_sets.join(alias_it.first, alias_it.second); + } + + // Check connected-ness: + // If there is no independent compute flow + // on this fusion graph, all outputs will be + // equivalent/connected to the first output. 
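As an illustrative aside, the connectivity check described above is a plain union-find: each expression joins its inputs and outputs into one component, aliased outputs are joined too, and the fusion counts as connected when every output lands in the same component as the first. A minimal stand-alone sketch with integer ids standing in for IR values:

    #include <unordered_map>
    #include <vector>

    struct DisjointSetSketch {
      std::unordered_map<int, int> parent;

      int find(int x) {
        if (parent.find(x) == parent.end()) {
          parent[x] = x; // lazily make x its own root
        }
        while (parent[x] != x) {
          x = parent[x];
        }
        return x;
      }
      void join(int a, int b) {
        parent[find(a)] = find(b);
      }
      bool equivalent(int a, int b) {
        return find(a) == find(b);
      }
    };

    struct ToyExpr {
      std::vector<int> inputs;
      std::vector<int> outputs;
    };

    bool isConnectedGraphSketch(
        const std::vector<ToyExpr>& exprs,
        const std::vector<int>& fusion_outputs) {
      if (fusion_outputs.empty()) {
        return true; // trivial case treated as connected
      }
      DisjointSetSketch components;
      for (const auto& e : exprs) {
        if (e.inputs.empty()) {
          continue; // the real code asserts every expr has at least one input
        }
        for (int in : e.inputs) {
          components.join(e.inputs[0], in);
        }
        for (int out : e.outputs) {
          components.join(e.inputs[0], out);
        }
      }
      // (Aliased output pairs would be joined here as well.)
      for (int out : fusion_outputs) {
        if (!components.equivalent(fusion_outputs[0], out)) {
          return false;
        }
      }
      return true;
    }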
+ auto output0 = fusion->outputs()[0]; + for (auto output : fusion->outputs()) { + if (!component_sets.areEquivalent(output0, output)) { + return false; + } + } + return true; +} + } // namespace SchedulerRuntimeInfo::SchedulerRuntimeInfo( @@ -634,39 +681,10 @@ bool SchedulerEntry::sameAs(const SchedulerEntry* other) { } namespace { -template -inline bool isTrivialReduction(REDUCTION_OP* red) { - auto o_tv = red->out()->template as(); - // Assuming graph unscheduled at this point. - for (auto id : o_tv->getRootDomain()) { - if (id->isReduction() && !id->extent()->isOneInt()) { - return false; - } - } - return true; -} - -template -std::vector findReductionOps(Fusion* fusion) { - std::vector red_ops; - for (auto expr : fusion->exprs()) { - if (auto red = dynamic_cast(expr)) { - if (!isTrivialReduction(red)) { - red_ops.push_back(red); - } - } - } - return red_ops; -} - std::vector findTransposeOps(Fusion* fusion) { - std::vector transpose_ops; - for (auto expr : fusion->exprs()) { - if (auto transpose_op = dynamic_cast(expr)) { - transpose_ops.push_back(transpose_op); - } - } - return transpose_ops; + auto exprs = fusion->exprs(); + auto transpose_ops = ir_utils::filterByType(exprs); + return std::vector(transpose_ops.begin(), transpose_ops.end()); } static bool checkPatternEquivalence( @@ -765,9 +783,8 @@ class ReductionScheduler : public SchedulerEntry { } // Make sure reduction axes are consistent through the fusion - if (findReductionOps(fusion).size() + - findReductionOps(fusion).size() > - 1) { + auto reduction_ops = ir_utils::getReductionOps(fusion); + if (reduction_ops.size() > 1) { // Before examining the reduction axes want to quickly // check the reductions have the same axis width // to avoid building root domain map in easier cases @@ -857,9 +874,16 @@ class PointWiseScheduler : public SchedulerEntry { } static bool canScheduleCompileTime(Fusion* fusion) { - auto red_ops = findReductionOps(fusion); - auto welford_ops = findReductionOps(fusion); - return red_ops.empty() && welford_ops.empty(); + // Currently using the same path as the scheduler + // to eliminate mismatch between canSchedule and + // schedule pointwise. + if (!hasReferenceTensorView(fusion)) { + return false; + } + + auto reduction_ops = ir_utils::getReductionOps(fusion); + auto welford_ops = ir_utils::filterByType(reduction_ops); + return reduction_ops.empty() && welford_ops.empty(); } static bool canScheduleRunTime( @@ -900,6 +924,14 @@ class PersistentKernelScheduler : public SchedulerEntry { } static bool canScheduleCompileTime(Fusion* fusion) { + auto reduction_ops = ir_utils::getReductionOps(fusion); + auto welford_ops = ir_utils::filterByType(reduction_ops); + // For persistent schedule we want welford translated to average and + // standard deviation reductions. + if (welford_ops.begin() != welford_ops.end()) { + return false; + } + auto view_tvs = scheduler_utils::getViewTVs(fusion); if (view_tvs.size() > 0) { return false; @@ -1079,8 +1111,13 @@ bool checkCanSchedule( // since for all current use cases // it has to pass all the compile time checks to create a data cache for this // fusion. 
- if (!data_cache && !SchedulerType::canScheduleCompileTime(fusion)) { - return false; + if (!data_cache) { + if (!isConnectedFusionGraph(fusion)) { + return false; + } + if (!SchedulerType::canScheduleCompileTime(fusion)) { + return false; + } } return SchedulerType::canScheduleRunTime(fusion, runtime_info, data_cache); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 7ce9addf0cb..90b348236cf 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -287,6 +287,15 @@ class PersistentBufferResolution : public IterVisitor { } if (tv->hasReduction()) { + if (std::any_of( + resolution_points_.begin(), + resolution_points_.end(), + [&tv](TensorView* resolution_point) { + return DependencyCheck::isDependencyOf(resolution_point, tv); + })) { + // If already resolved, don't start a new reduction path. + return; + } on_reduction_path_.emplace(tv); } } @@ -587,7 +596,7 @@ void computeAtBetween( return mapped_to_trivial_reduction.count(id); }); - pos = pos_it == consumer->domain()->domain().end() + auto consumer_pos = pos_it == consumer->domain()->domain().end() ? pos : std::min( (int)std::distance( @@ -596,7 +605,7 @@ void computeAtBetween( (pos < 0 ? pos + (int)consumer->nDims() : pos)); // Assume we don't want to reset computeAt on tensors that have already // performed it. - producer->computeAt(consumer, pos, mode); + producer->computeAt(consumer, consumer_pos, mode); } } } @@ -1038,15 +1047,22 @@ std::vector> cacheAndForkOutputs( } namespace { +// If this is an rfactored reduction domain, actually check the root domain, +// this is because the rfactored reduction tensorview has the vectorized +// dimension, but that means the rfactor domain could have reordered what we +// consider the "inner most" allocated position on it if we consider the rfactor +// dimension. IterDomain* innerMostRootDim(TensorView* tv) { if (tv->nDims() == 0) { return nullptr; } IterDomain* inner_most_id = nullptr; - for (auto it = tv->getMaybeRFactorDomain().rbegin(); - it != tv->getMaybeRFactorDomain().rend(); - it++) { + auto root_domain = tv->hasReduction() && tv->hasRFactor() + ? tv->getRootDomain() + : tv->getMaybeRFactorDomain(); + + for (auto it = root_domain.rbegin(); it != root_domain.rend(); it++) { if ((*it)->isReduction() && tv->isFusionInput()) { continue; } @@ -1084,7 +1100,7 @@ IterDomain* projectIdToRoot( return reference_id; } - auto replay_exprs = ExprSort::getExprs(tv->fusion(), {reference_id}); + auto replay_exprs = StmtSort::getExprs(tv->fusion(), {reference_id}, false); if (replay_exprs.empty()) { return reference_id; } @@ -1193,12 +1209,16 @@ std::unordered_set FindAllMappedDims::from( TensorView* tv, IterDomain* id, bool vectorize_pass) { + auto root_domain = tv->hasReduction() && tv->hasRFactor() + ? 
tv->getRootDomain() + : tv->getMaybeRFactorDomain(); + TORCH_INTERNAL_ASSERT( std::find_if( - tv->getMaybeRFactorDomain().begin(), - tv->getMaybeRFactorDomain().end(), + root_domain.begin(), + root_domain.end(), [&id](IterDomain* root_id) { return root_id == id; }) != - tv->getMaybeRFactorDomain().end(), + root_domain.end(), "Tried to map out ", id, " from TV ", diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 2bf8967f74e..911bda3da04 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -3,10 +3,13 @@ #include #include #include +#include #include #include #include #include +#include +#include // Cleanup #include @@ -24,8 +27,14 @@ DataType aten_opt_type_map(const c10::optional& scalar_type) { } } // namespace -TensorView::TensorView(TensorDomain* domain, DataType dtype, MemoryType mtype) - : Val(ValType::TensorView, dtype), domain_(domain), memory_type_(mtype) { +TensorView::TensorView( + IrBuilderPasskey passkey, + TensorDomain* domain, + DataType dtype, + MemoryType mtype) + : Val(passkey, ValType::TensorView, dtype), + domain_(domain), + memory_type_(mtype) { // Don't do this after transforms if (domain_->domain() == domain_->getRootDomain()) { // Mark the size-1 axes as broadcast to support implicit broadcast semantic @@ -38,10 +47,15 @@ TensorView::TensorView(TensorDomain* domain, DataType dtype, MemoryType mtype) } } -TensorView::TensorView(const std::shared_ptr& tensor_type) - : Val(ValType::TensorView, - aten_opt_type_map(tensor_type->scalarType()), - false) { +TensorView::TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& tensor_type) + : Val(passkey, + ValType::TensorView, + aten_opt_type_map(tensor_type->scalarType())) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); std::vector sizes; TORCH_CHECK( @@ -51,13 +65,14 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) if (tensor_type->sizes()[i].has_value() && tensor_type->sizes()[i].value() == 1) { // If size is known to be 1, assuem it needs to be broadcasted. 
- sizes.push_back(new IterDomain( - new Int(0), - new Int(1), + sizes.push_back(IrBuilder::create( + passkey.ir_container_->zeroVal(), + passkey.ir_container_->oneVal(), ParallelType::Serial, IterType::BroadcastWithStride)); } else { - sizes.push_back(new IterDomain(new Int(0), new Int())); + sizes.push_back(IrBuilder::create( + passkey.ir_container_->zeroVal(), IrBuilder::create())); } } @@ -92,8 +107,16 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) } } - domain_ = new TensorDomain(sizes, contig_info); - name_ = fusion_->registerVal(this); + domain_ = IrBuilder::create(sizes, contig_info); +} + +TensorView::TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& jit_value) + : TensorView(passkey, jit_value->type()->cast()) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); } TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) @@ -102,7 +125,9 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) compute_at_pos_(src->compute_at_pos_), max_producer_pos_(src->max_producer_pos_), memory_type_(src->memory_type_), - swizzle_type_(src->swizzle_type_) { + swizzle_type_(src->swizzle_type_), + is_double_buffered_(src->is_double_buffered_), + cpu_scalar_(src->cpu_scalar_) { for (const auto id : src->axesToSwizzle()) { axes_to_swizzle_.push_back(ir_cloner->clone(id)); } @@ -152,6 +177,18 @@ std::vector::size_type TensorView::nDims() const { return domain()->nDims(); } +// sets cpu_scalar_ value, which is special handling for CPU based zero-dim +// tensors (i.e. CPU Tensors that only have one value). This is only used if +// on an input value, otherwise ignored. This is important as special handling +// because these "scalars" should be type promoted as a tensor, but we want to +// avoid explicit copying of the data, so we want to pass the data value as a +// standard kernel argument value. 
+void TensorView::setCpuScalar(bool is_cpu_scalar) { + TORCH_INTERNAL_ASSERT( + nDims() == 0, "Only 0-dim tensors can be marked as a cpu scalar."); + cpu_scalar_ = is_cpu_scalar; +} + IterDomain* TensorView::axis(int pos) const { TORCH_INTERNAL_ASSERT( nDims() > 0, "Tried to access an axis in a 0-dim TensorView"); @@ -167,6 +204,9 @@ IterDomain* TensorView::axis(int pos) const { } void TensorView::setComputeAt(unsigned int pos, bool decrease) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); if (pos <= compute_at_pos_ && !decrease) { return; } @@ -182,6 +222,9 @@ void TensorView::setComputeAt(unsigned int pos, bool decrease) { } void TensorView::setMaxProducer(unsigned int pos, bool decrease) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); if (pos <= max_producer_pos_ && !decrease) { return; } @@ -200,6 +243,9 @@ TensorView* TensorView::computeAt( TensorView* consumer, int position, ComputeAtMode mode) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // Make sure this and consumer are not the same tensor, that's illegal TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)"); @@ -228,6 +274,9 @@ TensorView* TensorView::computeWith( TensorView* consumer, int position, ComputeAtMode mode) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // Make sure this and consumer are not the same tensor, that's illegal TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)"); @@ -290,7 +339,7 @@ TensorView* TensorView::split( unsigned int factor, bool inner_split, bool trim_out_of_bounds) { - split(axis, new Int(factor), inner_split, trim_out_of_bounds); + split(axis, IrBuilder::create(factor), inner_split, trim_out_of_bounds); return this; } @@ -336,6 +385,9 @@ TensorView* TensorView::merge(int axis_o, int axis_i) { } TensorView* TensorView::reorder(const std::unordered_map& old2new_) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); TORCH_INTERNAL_ASSERT( !(nDims() == 0 && old2new_.size() > 0), "Tried to reorder a 0-dim TensorView"); @@ -383,6 +435,9 @@ TensorView* TensorView::reorder(const std::unordered_map& old2new_) { TensorView* TensorView::swizzle( SwizzleType type, const std::vector& axes) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); swizzle_type_ = type; // Clear previously set swizzle axes if any @@ -432,6 +487,9 @@ TensorView* TensorView::swizzle( } TensorView* TensorView::rFactor(const std::vector& axes) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // TODO: I think we should do this but // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the // moment. @@ -462,7 +520,8 @@ TensorView* TensorView::rFactor(const std::vector& axes) { auto consumer_domain = domain_pair.second; // This domain will be the consumer, so create the producer - TensorView* producer = new TensorView(producer_domain, getDataType().value()); + TensorView* producer = + IrBuilder::create(producer_domain, getDataType().value()); // Set domain of consumer setDomain(consumer_domain); @@ -470,14 +529,14 @@ TensorView* TensorView::rFactor(const std::vector& axes) { // Setup dependency chain, inserting producer before this op. 
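Reviewer note: TensorView::split above now builds its factor through IrBuilder::create rather than a raw new Int, but the scheduling semantics are unchanged: an inner split by factor f turns one axis of extent E into an outer axis of extent ceilDiv(E, f) and an inner axis of extent f, with index i mapping to (i / f, i % f). A small standalone sketch of that arithmetic, under the assumption of an inner split with no out-of-bounds trimming (this is not the real IterDomain machinery):

#include <cassert>
#include <cstdint>
#include <utility>

// ceilDiv as used when splitting an extent by a factor.
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Inner split: axis of extent `extent` -> (outer, inner) extents.
std::pair<int64_t, int64_t> splitExtents(int64_t extent, int64_t factor) {
  return {ceilDiv(extent, factor), factor};
}

// Index mapping for the same split.
std::pair<int64_t, int64_t> splitIndex(int64_t i, int64_t factor) {
  return {i / factor, i % factor};
}

int main() {
  auto ext = splitExtents(10, 4); // {3, 4}: three outer iterations of width 4
  auto idx = splitIndex(9, 4);    // element 9 lands at (outer 2, inner 1)
  assert(ext.first == 3 && ext.second == 4);
  assert(idx.first == 2 && idx.second == 1);
  return 0;
}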
// Expr* producer_definition = - new ReductionOp( + IrBuilder::create( this_definition->getReductionOpType(), this_definition->init(), producer, this_definition->in()); // Expr* consumer_definition = - new ReductionOp( + IrBuilder::create( this_definition->getReductionOpType(), this_definition->init(), consumer, @@ -489,6 +548,9 @@ TensorView* TensorView::rFactor(const std::vector& axes) { TensorView* TensorView::welfordRfactorHelper( TensorView* tv, const std::vector& axes) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // Hack: // Semantically we should always keep the outputs of welfordOp scheduled // the same but the user end cannot guarantee that. @@ -520,7 +582,8 @@ TensorView* TensorView::welfordRfactorHelper( std::vector new_contig( tv->domain()->contiguity().begin(), tv->domain()->contiguity().end()); // replace tensor domain of target tv - tv->setDomain(new TensorDomain(tv->getRootDomain(), new_id, new_contig)); + tv->setDomain(IrBuilder::create( + tv->getRootDomain(), new_id, new_contig)); } // Split tensor view into 2 parts @@ -532,7 +595,7 @@ TensorView* TensorView::welfordRfactorHelper( // This domain will be the consumer, so create the producer TensorView* producer = - new TensorView(producer_domain, tv->getDataType().value()); + IrBuilder::create(producer_domain, tv->getDataType().value()); // Set domain of consumer tv->setDomain(consumer_domain); @@ -545,6 +608,9 @@ WelfordResult TensorView::rFactor( TensorView* avg, TensorView* var, TensorView* n) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView"); FusionGuard fg(fusion()); TORCH_CHECK( @@ -588,7 +654,7 @@ WelfordResult TensorView::rFactor( // Setup dependency chain, inserting producer before this op. // Expr* producer_definition = - new WelfordOp( + IrBuilder::create( producer_avg, producer_var, producer_n, /*out var/avg/count */ @@ -600,7 +666,7 @@ WelfordResult TensorView::rFactor( wop->inN()); // Expr* consumer_definition = - new WelfordOp( + IrBuilder::create( avg, var, n, @@ -615,6 +681,9 @@ WelfordResult TensorView::rFactor( } TensorView* TensorView::cache_before() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); FusionGuard fg(fusion()); TORCH_CHECK( @@ -652,8 +721,10 @@ TensorView* TensorView::cache_before() { // This domain will be the consumer which needs a new domain, so replace the // producers domain with this domain. - TensorView* producer = new TensorView( - new TensorDomain( + TensorView* producer = IrBuilder::create( + container(), + IrBuilder::create( + container(), domain()->getRootDomain(), domain()->getRFactorDomain(), domain()->domain(), @@ -671,8 +742,10 @@ TensorView* TensorView::cache_before() { new_root_domain[i++] = dom->clone(); } - consumer->setDomain(new TensorDomain( - new_root_domain, std::vector(new_root_domain.size(), true))); + consumer->setDomain(IrBuilder::create( + container(), + new_root_domain, + std::vector(new_root_domain.size(), true))); // Insert producer - Cache_Before (CB) - before this TV. 
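Reviewer note: the Welford rFactor path above splits one WelfordOp into a producer that computes partial (avg, var, N) triples and a consumer that finishes the reduction. Combining two partial results follows the standard parallel Welford update (Chan et al.); a standalone sketch, where m2 is the running sum of squared deviations that the variance buffer effectively accumulates before normalization (the naming here is an assumption, not the codegen's):

#include <cassert>
#include <cmath>

struct WelfordTriple {
  double avg = 0.0;
  double m2 = 0.0; // sum of squared deviations from the mean
  double n = 0.0;  // element count
};

// Merge two partial Welford results into one (Chan et al. update).
WelfordTriple welfordMerge(const WelfordTriple& a, const WelfordTriple& b) {
  if (a.n == 0) return b;
  if (b.n == 0) return a;
  WelfordTriple out;
  out.n = a.n + b.n;
  const double delta = b.avg - a.avg;
  out.avg = a.avg + delta * (b.n / out.n);
  out.m2 = a.m2 + b.m2 + delta * delta * (a.n * b.n / out.n);
  return out;
}

int main() {
  // Partial results for {1, 2} and {3, 4}; merged stats match {1, 2, 3, 4}.
  WelfordTriple a{1.5, 0.5, 2.0};
  WelfordTriple b{3.5, 0.5, 2.0};
  auto m = welfordMerge(a, b);
  assert(std::abs(m.avg - 2.5) < 1e-12);
  assert(std::abs(m.m2 - 5.0) < 1e-12);
  return 0;
}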
// Before: Prev TV -> [Definition Op] -> This TV @@ -684,7 +757,7 @@ TensorView* TensorView::cache_before() { ir_utils::replaceValInExpr(definition(), this, producer); // Expr* producer_uses = - new UnaryOp(UnaryOpType::Set, consumer, producer); + IrBuilder::create(container(), UnaryOpType::Set, consumer, producer); // definition_ is no longer valid // setDefinition(nullptr); @@ -697,6 +770,9 @@ TensorView* TensorView::cache_before() { } TensorView* TensorView::cache_fork() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); FusionGuard fg(fusion()); // Before: [Expr] -> This TV (Global Output) -> [Usage Expr] @@ -704,7 +780,7 @@ TensorView* TensorView::cache_fork() { // (Fork) -> [Set Expr] -> New TV (Global Output) TORCH_CHECK( - fusion()->hasOutput(this) && !this->uses().empty(), + this->isFusionOutput() && !this->uses().empty(), "Error adding cache_fork ", this, " this TensorView must be an output with subsequent uses"); @@ -717,14 +793,16 @@ TensorView* TensorView::cache_fork() { // This domain will be the producer, so create the consumer auto root_domain = TensorDomain::noReductions(getMaybeRFactorDomain()); - TensorView* new_output = new TensorView( - new TensorDomain( + TensorView* new_output = IrBuilder::create( + container(), + IrBuilder::create( + container(), IterDomain::clone(root_domain), std::vector(root_domain.size(), true)), getDataType().value()); // Create write operation from this TV to new output - new UnaryOp(UnaryOpType::Set, new_output, this); + IrBuilder::create(container(), UnaryOpType::Set, new_output, this); // The new TV becomes an output. // New TV has global memory type. @@ -739,13 +817,14 @@ TensorView* TensorView::cache_fork() { } TensorView* TensorView::cache_after() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); FusionGuard fg(fusion()); - const bool kIsFusionInput = fusion()->hasInput(this); - // Get all the uses for this Tensorview TORCH_CHECK( - !fusion()->hasOutput(this), + !isFusionOutput(), "Error adding cache_after ", this, " we restrict using cache_after on an output."); @@ -759,7 +838,7 @@ TensorView* TensorView::cache_after() { // It also did additional transformation when this tensor is an // input and the outputs of its consumers have computeAt. Make sure // we no longer rely on that behavior. 
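Reviewer note: cache_before above rewires the graph exactly as the Before/After comments describe — the defining op now writes into a new cache tensor, and a Set op copies that cache into the original tensor. A minimal standalone sketch of that rewiring on toy value/expr structs (these structs are purely illustrative, not the real Fusion IR):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Toy IR: values and the expressions that produce/consume them.
struct Val {
  std::string name;
};
struct Expr {
  std::string op;
  std::vector<Val*> inputs;
  std::vector<Val*> outputs;
};

// Insert a cache tensor between `def` and its output `orig`:
//   before: def -> orig
//   after:  def -> cache, then set(cache) -> orig
Expr cacheBefore(Expr& def, Val& orig, Val& cache) {
  std::replace(def.outputs.begin(), def.outputs.end(), &orig, &cache);
  return Expr{"set", {&cache}, {&orig}};
}

int main() {
  Val in{"tv0"}, out{"tv1"}, cb{"tv1_cache"};
  Expr add{"add", {&in}, {&out}};
  Expr set = cacheBefore(add, out, cb);
  std::cout << add.op << " now writes " << add.outputs[0]->name << "; "
            << set.op << " copies it into " << set.outputs[0]->name << "\n";
  return 0;
}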
- if (kIsFusionInput) { + if (isFusionInput()) { for (const auto& expr : uses()) { for (TensorView* output : ir_utils::filterByType(expr->outputs())) { @@ -782,9 +861,12 @@ TensorView* TensorView::cache_after() { } // This domain will be the producer, so create the consumer - TensorView* consumer = new TensorView( - new TensorDomain( - new_root_domain, std::vector(new_root_domain.size(), true)), + TensorView* consumer = IrBuilder::create( + container(), + IrBuilder::create( + container(), + new_root_domain, + std::vector(new_root_domain.size(), true)), getDataType().value()); // Set domain of producer - No Change @@ -800,14 +882,14 @@ TensorView* TensorView::cache_after() { } // Expr* consumer_definition = - new UnaryOp(UnaryOpType::Set, consumer, producer); + IrBuilder::create(container(), UnaryOpType::Set, consumer, producer); return consumer; } void TensorView::setMemoryType(MemoryType mt) { memory_type_ = mt; - if (fusion()->hasInput(this) || fusion()->hasOutput(this)) { + if (isFusionInput() || isFusionOutput()) { TORCH_INTERNAL_ASSERT( mt == MemoryType::Global, "Tried to set an input or output to the fusion to a non-global memory type."); @@ -832,7 +914,23 @@ void TensorView::clearReductionIterDomains() { } } - setDomain(new TensorDomain(new_root, new_contig)); + setDomain(IrBuilder::create(container(), new_root, new_contig)); +} + +void TensorView::doubleBuffer() { + // Early correctness checking. May miss eventual errors as the + // checks depend on memory types and parallelization, which may not + // be finalized until lowering. + validateDoubleBufferedTensor(this); + is_double_buffered_ = true; +} + +bool TensorView::isEmptyTensor() const { + auto& root_domain = getMaybeRFactorDomain(); + return std::all_of( + root_domain.begin(), root_domain.end(), [](IterDomain* id) { + return id->extent()->isZeroInt(); + }); } TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) { @@ -872,7 +970,8 @@ TensorView* TensorViewBuilder::build() const { std::vector domain(ndims_, nullptr); for (const auto i : c10::irange(ndims_)) { if (shape_.empty() || shape_[i] == -1) { - domain[i] = new IterDomain(new Int(0), new Int()); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create()); } else { TORCH_CHECK( shape_[i] >= 0, @@ -880,19 +979,22 @@ TensorView* TensorViewBuilder::build() const { "For a tensor representing a single scalar use ndims = 0 with no sizes set."); if (shape_[i] == 1) { // If size is known to be 1, assume it needs to be broadcasted. 
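Reviewer note: TensorViewBuilder::build (started above and continued just below) applies a simple per-dimension rule when turning a requested shape into IterDomains: an unspecified size (-1, or no shape at all) becomes a symbolic extent, a size of 1 is assumed to be a stride-carrying broadcast, and any other positive size becomes a constant extent. A small standalone sketch of that classification (the labels are descriptive strings, not the real IterType enum):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Mirror of the per-dimension decision in the builder: what kind of
// IterDomain does a requested size produce?
std::string classifyDim(int64_t size) {
  if (size == -1) {
    return "Iteration (symbolic extent)"; // extent bound at runtime
  }
  if (size == 1) {
    return "BroadcastWithStride (extent 1)"; // implicit broadcast dimension
  }
  return "Iteration (constant extent " + std::to_string(size) + ")";
}

int main() {
  for (int64_t s : std::vector<int64_t>{-1, 1, 128}) {
    std::cout << s << " -> " << classifyDim(s) << "\n";
  }
  return 0;
}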
- domain[i] = new IterDomain( - new Int(0), - new Int(1), + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), ParallelType::Serial, IterType::BroadcastWithStride); } else { - domain[i] = new IterDomain(new Int(0), new Int(shape_[i])); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create(shape_[i])); } } } // Create the final TensorView - return new TensorView(new TensorDomain(domain, contiguity_), dtype_); + return IrBuilder::create( + IrBuilder::create(domain, contiguity_), dtype_); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 54136616268..bae77943b33 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -228,7 +228,7 @@ BestEffortReplay::BestEffortReplay( } // Grab expr history of iter domains in target_domain - std::vector target_exprs = ExprSort::getExprs( + std::vector target_exprs = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector(target_domain.begin(), target_domain.end())); @@ -239,7 +239,7 @@ BestEffortReplay::BestEffortReplay( // replay_domain map. // Map replay domain's IterDomains to the Exprs they're used in - std::vector replay_exprs = ExprSort::getExprs( + std::vector replay_exprs = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector(replay_domain.begin(), replay_domain.end())); @@ -561,7 +561,7 @@ struct ConsumerForwardingInfo { auto consumer_bcast_ids_not_in_producer = consumer_bcast_roots_not_in_producer; - std::vector consumer_history = ExprSort::getExprs( + std::vector consumer_history = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector( consumer->domain()->domain().begin(), @@ -706,7 +706,7 @@ BestEffortReplay BestEffortReplay::replayCasP( } // Grab all exprs used to make the forwarded compliments - auto compliment_exprs = ExprSort::getExprs( + auto compliment_exprs = StmtSort::getExprs( FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); // Figure out if there are any leaves in compliment_exprs that aren't diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index cde502d636e..f1c4ae378b5 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index d0d03532cd6..7ea96f74bf1 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -49,23 +50,26 @@ class ReplaySelf : public ReplayTransformations { // Manually replay the split, following the output of the operations. // This is so rfactor ops are replayed correctly. - IterDomain* ido = new IterDomain( - new Int(0), + IterDomain* ido = IrBuilder::create( + s->container(), + s->container()->zeroVal(), s->innerSplit() ? remainder->as() : s->factor(), s->outer()->getParallelType(), s->outer()->getIterType(), s->outer()->isRFactorProduct()); // inner IterDomain - IterDomain* idi = new IterDomain( - new Int(0), + IterDomain* idi = IrBuilder::create( + s->container(), + s->container()->zeroVal(), s->innerSplit() ? 
s->factor() : remainder->as(), s->inner()->getParallelType(), s->inner()->getIterType(), s->inner()->isRFactorProduct()); // Generate the split node - new Split( + IrBuilder::create( + s->container(), ido, idi, mapped, @@ -112,14 +116,16 @@ class ReplaySelf : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - IterDomain* merged_id = new IterDomain( - new Int(0), + IterDomain* merged_id = IrBuilder::create( + m->container(), + m->container()->zeroVal(), merged_id_size->as(), m->out()->getParallelType(), m->outer()->getIterType(), m->out()->isRFactorProduct()); - new Merge(merged_id, id_outer_mapped, id_inner_mapped); + IrBuilder::create( + m->container(), merged_id, id_outer_mapped, id_inner_mapped); // Remove inputs from the leaf IDs leaf_ids_.erase(id_outer_mapped); @@ -197,7 +203,8 @@ TensorDomain* TransformReplay::fullSelfReplay( "Error during replay, didn't replay an axis."); new_rfactor_domain[i++] = it->second; } - return new TensorDomain( + return IrBuilder::create( + self->container(), new_self_root->getRootDomain(), new_rfactor_domain, new_domain, @@ -205,8 +212,11 @@ TensorDomain* TransformReplay::fullSelfReplay( } } - return new TensorDomain( - new_self_root->getRootDomain(), new_domain, new_self_root->contiguity()); + return IrBuilder::create( + self->container(), + new_self_root->getRootDomain(), + new_domain, + new_self_root->contiguity()); } // Producer could have rfactor axes which consumer may want replayed. We can @@ -407,7 +417,8 @@ std::pair TransformReplay::replayPasC( new_IDs.push_back(id); } } - TensorDomain* replayed = new TensorDomain( + TensorDomain* replayed = IrBuilder::create( + producer->container(), producer->getRootDomain(), producer->getRFactorDomain(), new_IDs, @@ -604,7 +615,8 @@ std::pair TransformReplay::replayCasP( if (used_IDs.find(id) == used_IDs.end()) new_IDs.push_back(id); - TensorDomain* replayed = new TensorDomain( + TensorDomain* replayed = IrBuilder::create( + consumer->container(), consumer->getRootDomain(), consumer->getRFactorDomain(), new_IDs, diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 92898b54ba7..1fd3d110200 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -1,7 +1,7 @@ #pragma once +#include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index 8ac28cf3a2c..5939ffee289 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -52,23 +53,26 @@ class ReplayRFactor : public ReplayTransformations { // Manually replay the split, making reduction = false and rfactor = true // outer IterDomain - IterDomain* ido = new IterDomain( - new Int(0), + IterDomain* ido = IrBuilder::create( + s->container(), + IrBuilder::create(s->container(), 0), s->innerSplit() ? remainder->as() : s->factor(), ParallelType::Serial, rfactor_outer ? IterType::Reduction : IterType::Iteration, true); // broadcast // inner IterDomain - IterDomain* idi = new IterDomain( - new Int(0), + IterDomain* idi = IrBuilder::create( + s->container(), + IrBuilder::create(s->container(), 0), s->innerSplit() ? s->factor() : remainder->as(), ParallelType::Serial, rfactor_inner ? 
IterType::Reduction : IterType::Iteration, true); // Generate the split node - new Split(ido, idi, mapped, s->factor(), s->innerSplit()); + IrBuilder::create( + s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); // Remove mapped id from leaf IDs leaf_ids_.erase(mapped); @@ -115,14 +119,16 @@ class ReplayRFactor : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - IterDomain* merged_id = new IterDomain( - new Int(0), + IterDomain* merged_id = IrBuilder::create( + m->container(), + IrBuilder::create(m->container(), 0), merged_id_size->as(), ParallelType::Serial, rfactor_output ? IterType::Reduction : IterType::Iteration, true); - new Merge(merged_id, id_outer_mapped, id_inner_mapped); + IrBuilder::create( + m->container(), merged_id, id_outer_mapped, id_inner_mapped); // Remove inputs from the leaf IDs leaf_ids_.erase(id_outer_mapped); @@ -238,7 +244,8 @@ TensorDomain* TransformRFactor::runReplay( for (auto id : orig_td_root) { // If this is an rfactor root, it will be a reduction in this stage if (rfactor_root_axes.find(id) != rfactor_root_axes.end()) { - new_root[i] = new IterDomain( + new_root[i] = IrBuilder::create( + id->container(), id->start(), id->extent(), id->stopOffset(), @@ -248,7 +255,8 @@ TensorDomain* TransformRFactor::runReplay( // If this is not an rfactor root, but a reduction root, it should be // turned into an iteration domain } else if (id->isReduction()) { - new_root[i] = new IterDomain( + new_root[i] = IrBuilder::create( + id->container(), id->start(), id->extent(), id->stopOffset(), @@ -296,7 +304,8 @@ TensorDomain* TransformRFactor::runReplay( if (dom->isRFactorProduct()) rfactor_root.push_back(dom); - return new TensorDomain( + return IrBuilder::create( + orig_td->container(), new_root, rfactor_root, new_domain, @@ -400,8 +409,11 @@ TensorDomain* TransformRFactor::runReplay2( } } - return new TensorDomain( - new_root, new_domain, std::vector(new_root.size(), true)); + return IrBuilder::create( + orig_td->container(), + new_root, + new_domain, + std::vector(new_root.size(), true)); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h index 551f67905b0..593eb287d0b 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/transform_view.cpp b/torch/csrc/jit/codegen/cuda/transform_view.cpp index ea4d188c092..433e34a11eb 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_view.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -44,11 +45,31 @@ class Transform { size_t index() const { return index_; } + + size_t originalIndex() const { + return original_index_; + } + + size_t newIndex() const { + return new_index_; + } + virtual ~Transform() = default; protected: - Transform(size_t index) : index_(index) {} + Transform(const ViewIndexState& state, size_t index) + : index_(index), + original_index_(state.original_view_index), + new_index_(Transform::computeNewIndex(state)) {} + const size_t index_ = 0; + const size_t original_index_ = 0; + const size_t new_index_ = 0; + + static size_t computeNewIndex(const ViewIndexState& state) { + return state.original_view_index - state.trivial_reduction_offset + + state.split_offset - state.merge_offset + state.broadcast_offset; + 
} }; //! Base class for all view tranformations - Merge, Split, Keep @@ -61,9 +82,11 @@ class ViewTransform : public Transform { std::vector& rfactor_domain) = 0; ~ViewTransform() override = default; + virtual bool isOriginalAxisDynamic() const = 0; + protected: ViewTransform(const ViewIndexState& state) - : Transform(ViewTransform::computeIndex(state)) {} + : Transform(state, ViewTransform::computeIndex(state)) {} static size_t computeIndex(const ViewIndexState& state) { return state.original_view_index - state.trivial_reduction_offset; @@ -71,6 +94,7 @@ class ViewTransform : public Transform { }; namespace { +typedef std::vector Sizes; const size_t kEmptyAxis = 0; const size_t kSingletonAxis = 1; @@ -86,6 +110,10 @@ class MergeTransform final : public ViewTransform { << std::endl; } + bool isOriginalAxisDynamic() const override { + return false; + } + void createRfactorDomain( const std::vector& new_root_domain, std::vector& rfactor_domain) override { @@ -108,14 +136,15 @@ class MergeTransform final : public ViewTransform { auto merged_extent = mul(merged_id->extent(), new_root_domain[index_ + 1]->extent()); - auto new_merged_id = new IterDomain( - new Int(0), + auto new_merged_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), merged_extent, ParallelType::Serial, IterType::Iteration, true); - new Merge(new_merged_id, merged_id, new_root_domain[index_ + 1]); + IrBuilder::create( + new_merged_id, merged_id, new_root_domain[index_ + 1]); rfactor_domain.push_back(new_merged_id); } @@ -140,6 +169,10 @@ class SplitTransform final : public ViewTransform { << " ARG: " << split_factor_ << std::endl; } + bool isOriginalAxisDynamic() const override { + return false; + } + void createRfactorDomain( const std::vector& new_root_domain, std::vector& rfactor_domain) override { @@ -150,7 +183,7 @@ class SplitTransform final : public ViewTransform { "\t Domain Size:\t", new_root_domain.size()); - auto factor = new Int(split_factor_); + auto factor = IrBuilder::create(split_factor_); IterDomain* id = nullptr; if (is_last_axis_rfactor_) { @@ -164,18 +197,22 @@ class SplitTransform final : public ViewTransform { Val* remainder = ceilDiv(id->extent(), factor); // outer loop IterDomain - IterDomain* factor_id = new IterDomain( - new Int(0), factor, id->getParallelType(), id->getIterType(), true); + IterDomain* factor_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + factor, + id->getParallelType(), + id->getIterType(), + true); // inner loop IterDomain - IterDomain* remainder_id = new IterDomain( - new Int(0), + IterDomain* remainder_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), remainder->as(), ParallelType::Serial, IterType::Iteration, true); - new Split(factor_id, remainder_id, id, factor, false); + IrBuilder::create(factor_id, remainder_id, id, factor, false); rfactor_domain.push_back(factor_id); rfactor_domain.push_back(remainder_id); @@ -195,6 +232,10 @@ class KeepTransform final : public ViewTransform { output << "Keep Index: " << index_ << std::endl; } + bool isOriginalAxisDynamic() const override { + return true; + } + void createRfactorDomain( const std::vector& new_root_domain, std::vector& rfactor_domain) override { @@ -214,17 +255,11 @@ class KeepTransform final : public ViewTransform { class BroadcastTransform final : public Transform { public: BroadcastTransform(const ViewIndexState& state) - : Transform(BroadcastTransform::computeIndex(state)) {} + : Transform(state, Transform::computeNewIndex(state)) {} void toString(std::stringstream& 
output) const override { output << "Bcast Index: " << index_ << std::endl; } - - private: - static size_t computeIndex(const ViewIndexState& state) { - return state.original_view_index - state.trivial_reduction_offset + - state.split_offset - state.merge_offset + state.broadcast_offset; - } }; //! For any implicit broadcast dimensions in the original view, we remove @@ -232,7 +267,7 @@ class BroadcastTransform final : public Transform { class TrivialReductionTransform final : public Transform { public: TrivialReductionTransform(const ViewIndexState& state) - : Transform(TrivialReductionTransform::computeIndex(state)) {} + : Transform(state, TrivialReductionTransform::computeIndex(state)) {} void toString(std::stringstream& output) const override { output << "1-Red Index: " << index_ << std::endl; @@ -249,10 +284,11 @@ class TrivialReductionTransform final : public Transform { class AnalyzeViewTransformation { public: AnalyzeViewTransformation( - const std::vector root_domain, - const std::vector& original_view, - const std::vector& new_view) - : root_domain_(root_domain), + const Sizes& original_view, + const Sizes& new_view, + std::vector root_domain = {}) + : default_implicit_broadcast_(root_domain.empty()), + root_domain_(root_domain), original_view_(original_view), new_view_(new_view), transform_view_(original_view) { @@ -264,6 +300,24 @@ class AnalyzeViewTransformation { TORCH_INTERNAL_ASSERT(kOriginalNumElements == kNewNumElements); } + AnalyzeViewConstraint constraint() { + findTransformation(); + TORCH_INTERNAL_ASSERT( + validate(), + "Analyze View Transformation failed to find valid transformation.\n", + toString()); + std::vector original_constraint( + original_view_.begin(), original_view_.end()); + std::vector new_constraint(new_view_.begin(), new_view_.end()); + for (auto& vt : view_transforms_) { + if (vt->isOriginalAxisDynamic()) { + original_constraint[vt->originalIndex()] = -1; + new_constraint[vt->newIndex()] = -1; + } + } + return {original_constraint, new_constraint}; + } + AnalyzeViewResult run() { findTransformation(); TORCH_INTERNAL_ASSERT( @@ -382,6 +436,15 @@ class AnalyzeViewTransformation { return true; } + bool isImplicitBroadcast(size_t original_view_index) const { + if (default_implicit_broadcast_) { + return original_view_[original_view_index] == 1; + } else { + TORCH_INTERNAL_ASSERT(!root_domain_.empty()); + return root_domain_[original_view_index]->isImplicitBroadcast(); + } + } + //! This utility class merges a fixed set of axes together //! according to some invariant. Implicit broadcast axes cannot be //! merged with standard iterDomains, so they are handled separately @@ -400,8 +463,7 @@ class AnalyzeViewTransformation { bool any_merge = false; for (size_t idx = 0; idx < num_merge_axes_; ++idx) { - if (avt_->root_domain_[state_.original_view_index] - ->isImplicitBroadcast()) { + if (avt_->isImplicitBroadcast(state_.original_view_index)) { avt_->addTrivialReductionTransform(); } else { avt_->addMergeTransform( @@ -603,9 +665,10 @@ class AnalyzeViewTransformation { std::vector> trivial_reduction_transforms_; + bool default_implicit_broadcast_ = true; const std::vector root_domain_; - const std::vector& original_view_; - const std::vector& new_view_; + const Sizes& original_view_; + const Sizes& new_view_; // transform_view is a mutable view and is initialized with the original_view. // It is used to track the current state of the original tensor domain. 
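Reviewer note: the Transform bookkeeping above tracks where each original axis lands in the new view — starting from the original index, trivial reductions and merges shift positions left while splits and broadcasts shift them right, which is exactly the computeNewIndex expression. A standalone sketch of that arithmetic on a plain offsets struct (the field names are copied from the state used above, but the struct itself is only for illustration):

#include <cassert>
#include <cstddef>

// Offsets accumulated while walking the original view left to right.
struct ViewIndexState {
  size_t original_view_index = 0;
  size_t trivial_reduction_offset = 0; // axes dropped so far
  size_t split_offset = 0;             // extra axes created by splits
  size_t merge_offset = 0;             // axes removed by merges
  size_t broadcast_offset = 0;         // axes inserted as broadcasts
};

// Position of the current original axis in the new (post-view) domain.
size_t computeNewIndex(const ViewIndexState& state) {
  return state.original_view_index - state.trivial_reduction_offset +
      state.split_offset - state.merge_offset + state.broadcast_offset;
}

int main() {
  // Original axis 3, after one merge and one split to its left:
  // the merge removed one slot, the split added one, so it stays at 3.
  ViewIndexState state;
  state.original_view_index = 3;
  state.merge_offset = 1;
  state.split_offset = 1;
  assert(computeNewIndex(state) == 3);
  return 0;
}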
@@ -622,7 +685,7 @@ class AnalyzeViewTransformation { // If transform size != original size for an axis, then the transformation // uses the last rfactor domain. Otherwise, it is a root domain // transformation. - std::vector transform_view_; + Sizes transform_view_; }; //! Create new TensorDomain with a modified rfactor domain using the specified @@ -644,7 +707,7 @@ TensorDomain* createViewDomain( t->createRfactorDomain(new_root_domain, rfactor_domain); } - return new TensorDomain( + return IrBuilder::create( new_root_domain, rfactor_domain, rfactor_domain, @@ -652,11 +715,19 @@ TensorDomain* createViewDomain( } //! Infer -1 value in new view sizes from original view sizes -std::vector inferNewViewShape( - const std::vector& original_view, +std::pair inferNewViewShape( + const std::vector& original_sizes, const std::vector& new_sizes) { - std::vector new_view(new_sizes.size()); + bool valid_original_sizes = std::all_of( + original_sizes.begin(), original_sizes.end(), [](int64_t dim) { + return dim > 0; + }); + TORCH_INTERNAL_ASSERT(valid_original_sizes); + Sizes original_view(original_sizes.begin(), original_sizes.end()); + Sizes new_view(new_sizes.size()); + + // TODO: refactor int64_t dynamic_index = -1; size_t new_size_num_elements = 1; for (size_t idx = 0; idx < new_sizes.size(); ++idx) { @@ -665,6 +736,7 @@ std::vector inferNewViewShape( dynamic_index == -1, "Only one dimension can by inferred.") dynamic_index = idx; } else { + TORCH_INTERNAL_ASSERT(new_sizes[idx] > 0); new_size_num_elements *= new_sizes[idx]; new_view[idx] = new_sizes[idx]; } @@ -676,7 +748,7 @@ std::vector inferNewViewShape( new_view[dynamic_index] = kNumElements / new_size_num_elements; } - return new_view; + return {original_view, new_view}; } } // namespace @@ -690,22 +762,24 @@ AnalyzeViewResult analyzeView( FUSER_PERF_SCOPE("analyzeView"); TORCH_INTERNAL_ASSERT( tv->getMaybeRFactorDomain().size() == original_sizes.size()); - - bool valid_original_sizes = std::all_of( - original_sizes.begin(), original_sizes.end(), [](int64_t dim) { - return dim > 0; - }); - - TORCH_INTERNAL_ASSERT(valid_original_sizes); - - std::vector original_view( - original_sizes.begin(), original_sizes.end()); - auto new_view = inferNewViewShape(original_view, new_sizes); + auto sizes = inferNewViewShape(original_sizes, new_sizes); AnalyzeViewTransformation analyzer( - tv->getRootDomain(), original_view, new_view); + sizes.first /* original_view */, + sizes.second /* new_view */, + tv->getRootDomain()); return analyzer.run(); } +AnalyzeViewConstraint analyzeViewConstraint( + const std::vector& original_sizes, + const std::vector& new_sizes) { + FUSER_PERF_SCOPE("analyzeViewConstraint"); + auto sizes = inferNewViewShape(original_sizes, new_sizes); + AnalyzeViewTransformation analyzer( + sizes.first /* original_view */, sizes.second /* new_view */); + return analyzer.constraint(); +} + //! Create new TensorDomain with a modified rfactor domain using the specified //! 
view transformations TensorDomain* transformView( diff --git a/torch/csrc/jit/codegen/cuda/transform_view.h b/torch/csrc/jit/codegen/cuda/transform_view.h index e7473a1b9b4..f8a986048be 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.h +++ b/torch/csrc/jit/codegen/cuda/transform_view.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -40,6 +40,11 @@ struct AnalyzeViewResult { std::vector> transforms; }; +struct AnalyzeViewConstraint { + std::vector original_constraint; + std::vector new_constraint; +}; + // Find the transformations necessary to convert TensorView // from original size to new size. AnalyzeViewResult analyzeView( @@ -47,6 +52,11 @@ AnalyzeViewResult analyzeView( const std::vector& original_sizes, const std::vector& new_sizes); +// Find the constraints derived from the view transformations +AnalyzeViewConstraint analyzeViewConstraint( + const std::vector& original_sizes, + const std::vector& new_sizes); + // Generate a new TensorDomain from the given view transformations. // The original root domain is kept in the new TensorDomain, // but a new rfactor domain is created from the view transformations. diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 3afb1b540b8..e883421eb1e 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -87,6 +87,9 @@ ValType promote_type(const ValType& t1, const ValType& t2) { (t1 == ValType::Scalar || t1 == ValType::NamedScalar)) { return ValType::Scalar; } + if (t1 == ValType::NamedScalar && t2 == ValType::NamedScalar) { + return ValType::Scalar; + } TORCH_CHECK(false, "Expected promotable ValTypes but got: ", t1, " and ", t2); } @@ -107,7 +110,7 @@ static const char* data_type2string(DataType t) { case DataType::Int32: return "int"; case DataType::Null: - return "nullptr"; + return "null_type"; default: break; } @@ -127,6 +130,10 @@ static const char* val_type2string(ValType t) { return "Scalar"; case ValType::NamedScalar: return "NamedScalar"; + case ValType::Predicate: + return "Predicate"; + case ValType::TensorIndex: + return "TensorIndex"; default: TORCH_INTERNAL_ASSERT(false, "No string found for val type."); } @@ -144,12 +151,38 @@ static const char* expr_type2string(ExprType t) { return "ReductionOp"; case ExprType::BroadcastOp: return "BroadcastOp"; + case ExprType::WelfordOp: + return "WelfordOp"; + case ExprType::TransposeOp: + return "TransposeOp"; case ExprType::ShiftOp: return "ShiftOp"; + case ExprType::GatherOp: + return "GatherOp"; + case ExprType::ViewOp: + return "ViewOp"; case ExprType::Split: return "Split"; case ExprType::Merge: return "Merge"; + case ExprType::Allocate: + return "Allocate"; + case ExprType::Sync: + return "Sync"; + case ExprType::InitMagicZero: + return "InitMagicZero"; + case ExprType::UpdateMagicZero: + return "UpdateMagicZero"; + case ExprType::ForLoop: + return "ForLoop"; + case ExprType::IfThenElse: + return "IfThenElse"; + case ExprType::GridReduction: + return "GridReduction"; + case ExprType::GridBroadcast: + return "GridBroadcast"; + case ExprType::GridWelford: + return "GridWelford"; default: TORCH_INTERNAL_ASSERT(false, "No string found for expr type."); } @@ -281,7 +314,6 @@ bool needFloatSuffix(BinaryOpType t) { case BinaryOpType::Atan2: case BinaryOpType::Div: case BinaryOpType::Fmod: - case BinaryOpType::Pow: return true; default: return false; @@ -522,6 +554,7 @@ static const char* supported_casts2string( case supported_switch_pair(DataType::Float, DataType::Int): case 
supported_switch_pair(DataType::Double, DataType::Int): return "(int64_t)"; + case supported_switch_pair(DataType::Int, DataType::Int32): case supported_switch_pair(DataType::Float, DataType::Int32): case supported_switch_pair(DataType::Double, DataType::Int32): return "(int32_t)"; diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 43aadb62006..ea7e8bd04d3 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include @@ -32,6 +32,8 @@ enum class ValType { TensorView, Scalar, NamedScalar, + Predicate, + TensorIndex, }; // Manual - The user provides the Bool value. Predicate generation is bypassed. @@ -73,6 +75,15 @@ enum class ExprType { ViewOp, Split, Merge, + Allocate, + Sync, + InitMagicZero, + UpdateMagicZero, + ForLoop, + IfThenElse, + GridReduction, + GridBroadcast, + GridWelford, }; enum class UnaryOpType { @@ -257,8 +268,11 @@ std::string stringifyThread(const ParallelType); std::string typePrefix(const DataType); // TODO: ThreadDim should be BlockDim and BlockDim should be GridDim +// Returns if parallel type is TID[x, y, z] TORCH_CUDA_CU_API bool isParallelTypeThreadDim(ParallelType); +// Returns if parallel type is BID[x, y, z] TORCH_CUDA_CU_API bool isParallelTypeBlockDim(ParallelType); +// Returns if parallel type is a grid or block parallelization dimension TORCH_CUDA_CU_API bool isParallelTypeThread(ParallelType); TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType); diff --git a/torch/csrc/jit/codegen/cuda/type_inference.cpp b/torch/csrc/jit/codegen/cuda/type_inference.cpp index 8c7d7d36a06..a8facc6a45b 100644 --- a/torch/csrc/jit/codegen/cuda/type_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/type_inference.cpp @@ -141,6 +141,7 @@ class NaiveTypePropagator { } // binary operations that forward meta info and broadcast shape: case aten::gelu_backward: + case aten::tanh_backward: case aten::mul: case aten::div: case aten::min: @@ -414,19 +415,14 @@ class NaiveTypePropagator { node->output()->setType(out_type->withDim(c10::nullopt)); break; } - /* - // TODO: Enable view in parser by detecting non-alias view operation - case aten::view: - case aten::reshape: { + case prim::unsqueeze_copy: + case prim::squeeze_copy: + case prim::reshape_copy: + case prim::view_copy: { auto out_type = node->input(0)->type()->cast(); - auto size_optional = constant_as>(node->input(1)); - TORCH_INTERNAL_ASSERT( - size_optional.has_value(), "The size parameter is required."); - auto new_size = size_optional->vec(); - node->output()->setType(out_type->withSizes(new_size)); + node->output()->setType(out_type); break; } - */ case aten::type_as: { const auto type0 = getInputTensorType(node, 0); const auto type1 = getInputTensorType(node, 1); diff --git a/torch/csrc/jit/codegen/cuda/type_promotion.cpp b/torch/csrc/jit/codegen/cuda/type_promotion.cpp index 016e8825acf..68a38e67378 100644 --- a/torch/csrc/jit/codegen/cuda/type_promotion.cpp +++ b/torch/csrc/jit/codegen/cuda/type_promotion.cpp @@ -55,13 +55,14 @@ at::native::ResultTypeState updateResultTypeState( TORCH_INTERNAL_ASSERT( !c10::isComplexType(scalar), "NvFuser does not support complex data types."); + at::native::ResultTypeState new_state = in_state; c10::ScalarType current = scalar; if (c10::isFloatingType(scalar)) { current = c10::typeMetaToScalarType(at::get_default_dtype()); } new_state.wrappedResult = - promoteTypesSkipUndefined(in_state.wrappedResult, scalar); + 
promoteTypesSkipUndefined(in_state.wrappedResult, current); return new_state; } @@ -195,11 +196,16 @@ std::vector promoteValues( Val* optionalCast(DataType dtype, Val* v) { TORCH_INTERNAL_ASSERT(v->getDataType().has_value()); + // Avoid casting Float/Int scalar to any corresponding FloatingPoint/Integral + // type in fusion. Instead, we cast them directly. The exception is Bool, + // which is always casted to the desired type. const bool kSameDtype = v->getDataType().value() == dtype; const bool kIsScalarFloat = !v->isA() && isFloatingPointType(dtype); + const bool kIsScalarInt = !v->isA() && isIntegralType(dtype); if (kSameDtype || - (kIsScalarFloat && isFloatingPointType(v->getDataType().value()))) { + (kIsScalarFloat && isFloatingPointType(v->getDataType().value())) || + (kIsScalarInt && isIntegralType(v->getDataType().value()))) { return v; } else { return castOp(dtype, v); diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index 67c8359b502..127078b45f7 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -143,6 +143,19 @@ void debugPrint(const c10::TensorTypePtr& type) { } #pragma clang diagnostic pop +bool is_cpu_scalar(const at::Tensor& tensor) { + return tensor.device().is_cpu() && tensor.numel() == 1 && tensor.dim() == 0; +} + +bool is_cpu_scalar(const c10::TensorType& tensor_type) { + auto opt_device = tensor_type.device(); + auto opt_dim = tensor_type.dim(); + auto opt_numel = tensor_type.numel(); + return opt_device.has_value() && opt_device.value().is_cpu() && + opt_dim.has_value() && opt_numel.has_value() && opt_dim.value() == 0 && + opt_numel.value() == 1; +} + bool isDebugDumpEnabled(DebugDumpOption option) { const static auto dump_options = parseDebugDumpOptions(); return dump_options.at(option); @@ -158,6 +171,14 @@ bool disableRNGUnrolling() { return disable_rng_unroll ? atoi(disable_rng_unroll) : false; } +std::vector getTensorSizes(TensorTypePtr const& tensor_type) { + TORCH_INTERNAL_ASSERT(tensor_type != nullptr, "Input must be a Tensor."); + auto optional_sizes = tensor_type->sizes().concrete_sizes(); + TORCH_INTERNAL_ASSERT( + optional_sizes.has_value(), "Missing size information for the tensor."); + return optional_sizes.value(); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index f56d0f8d52e..c035cdeae24 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include namespace torch { namespace jit { @@ -10,6 +11,9 @@ namespace cuda { void debugPrint(const c10::TensorTypePtr& type); +bool is_cpu_scalar(const at::Tensor& tensor); +bool is_cpu_scalar(const c10::TensorType& tensor_type); + //! Types of debug print-outs //! //! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable @@ -116,6 +120,8 @@ constexpr unsigned int switch_pair(T t1, T t2) { return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2; } +std::vector getTensorSizes(TensorTypePtr const& tensor_type); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index fa840621682..0129938568c 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -152,11 +152,11 @@ class AliasDb { * this. */ // Copy `existing`s aliasing info to `new_value`, and remove `existing`. 
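Reviewer note: the is_cpu_scalar helpers added above treat only genuine zero-dim, single-element CPU tensors as scalars, so their value can be promoted like a tensor yet passed as a plain kernel argument instead of a device pointer. A small check of that predicate against ATen tensors (building this requires libtorch; the local lambda simply repeats the condition from the helper above rather than including the nvFuser header):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto is_cpu_scalar = [](const at::Tensor& t) {
    return t.device().is_cpu() && t.numel() == 1 && t.dim() == 0;
  };

  at::Tensor zero_dim = at::scalar_tensor(1.0); // 0-dim CPU tensor, one element
  at::Tensor one_elem = at::ones({1});          // 1-dim CPU tensor, one element

  std::cout << std::boolalpha
            << is_cpu_scalar(zero_dim) << "\n"  // true
            << is_cpu_scalar(one_elem) << "\n"; // false: dim() == 1
  return 0;
}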
- void replaceWithNewValue(Value* existing, Value* new_value); + TORCH_API void replaceWithNewValue(Value* existing, Value* new_value); // Copy `from`s aliasing info to `to`. - void copyValue(Value* from, Value* to); + TORCH_API void copyValue(Value* from, Value* to); // Create a new `value` that does not alias anything else. - void createValue(const Value* value); + TORCH_API void createValue(const Value* value); // Enable more precise treatment of prim::TupleConstruct. void enablePreciseTupleContainerAnalysis(); diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 9ad8934d55b..ca07e63cb61 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1308,6 +1308,10 @@ def forward(self, a) -> MyModule: obj = obj.__original_fn _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) + # some functions are explicitly marked as not supported in script mode + if hasattr(obj, "__script_unsupported"): + raise RuntimeError("TorchScript error: " + obj.__script_unsupported) + _check_directly_compile_overloaded(obj) maybe_already_compiled_fn = _try_get_jit_cached_function(obj) if maybe_already_compiled_fn: