From 5ba9896d8c848ce1f63542ac252ac2468920fcb6 Mon Sep 17 00:00:00 2001 From: jiej Date: Mon, 14 Feb 2022 16:37:45 -0800 Subject: [PATCH] Nvfuser code bump 2_1_2022 (#72127) Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: https://github.com/pytorch/pytorch/pull/61439) Pull Request resolved: https://github.com/pytorch/pytorch/pull/72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9) --- aten/src/ATen/core/interned_strings.h | 4 + benchmarks/cpp/nvfuser/batch_norm.cpp | 5 +- .../cpp/nvfuser/batch_norm_backward.cpp | 3 +- benchmarks/cpp/nvfuser/bert.cpp | 18 +- benchmarks/cpp/nvfuser/gelu_backward.cpp | 19 +- benchmarks/cpp/nvfuser/heuristic_cache.cpp | 3 +- benchmarks/cpp/nvfuser/heuristic_lookup.cpp | 3 +- benchmarks/cpp/nvfuser/instance_norm.cpp | 5 +- benchmarks/cpp/nvfuser/layer_norm.cpp | 3 +- .../cpp/nvfuser/layer_norm_backward.cpp | 3 +- benchmarks/cpp/nvfuser/shape_inference.cpp | 3 +- benchmarks/cpp/nvfuser/softmax_dropout.cpp | 7 +- benchmarks/cpp/nvfuser/utils.cpp | 10 +- test/cpp/jit/test_gpu.cpp | 4137 ++++++++++------- test/cpp/jit/test_gpu_shift.cpp | 2224 ++++++--- test/cpp/jit/test_gpu_validator.h | 26 +- test/test_jit_cuda_fuser.py | 554 ++- tools/build_variables.bzl | 12 +- torch/csrc/jit/codegen/cuda/arith.cpp | 392 +- torch/csrc/jit/codegen/cuda/arith.h | 76 +- torch/csrc/jit/codegen/cuda/codegen.cpp | 655 +-- torch/csrc/jit/codegen/cuda/codegen.h | 2 +- torch/csrc/jit/codegen/cuda/compute_at.cpp | 60 +- torch/csrc/jit/codegen/cuda/compute_at.h | 10 +- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 99 - torch/csrc/jit/codegen/cuda/compute_at_map.h | 27 - torch/csrc/jit/codegen/cuda/dispatch.cpp | 492 +- torch/csrc/jit/codegen/cuda/dispatch.h | 455 +- .../jit/codegen/cuda/evaluator_common.cpp | 113 +- .../csrc/jit/codegen/cuda/evaluator_common.h | 55 +- torch/csrc/jit/codegen/cuda/executor.cpp | 146 +- torch/csrc/jit/codegen/cuda/executor.h | 19 +- .../jit/codegen/cuda/executor_kernel_arg.cpp | 78 +- .../jit/codegen/cuda/executor_kernel_arg.h | 35 +- .../csrc/jit/codegen/cuda/executor_utils.cpp | 369 +- torch/csrc/jit/codegen/cuda/executor_utils.h | 41 +- torch/csrc/jit/codegen/cuda/expr_evaluator.h | 2 +- torch/csrc/jit/codegen/cuda/fusion.cpp | 289 +- torch/csrc/jit/codegen/cuda/fusion.h | 95 +- .../jit/codegen/cuda/fusion_segmenter.cpp | 66 +- .../csrc/jit/codegen/cuda/fusion_segmenter.h | 8 +- torch/csrc/jit/codegen/cuda/graph_fuser.cpp | 338 +- torch/csrc/jit/codegen/cuda/index_compute.cpp | 1349 +++--- torch/csrc/jit/codegen/cuda/index_compute.h | 113 +- .../codegen/cuda/index_reference_replay.cpp | 82 +- .../jit/codegen/cuda/index_reference_replay.h | 15 +- .../csrc/jit/codegen/cuda/instrumentation.cpp | 2 +- torch/csrc/jit/codegen/cuda/interface.cpp | 374 +- torch/csrc/jit/codegen/cuda/interface.h | 2 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 136 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 148 +- .../{kernel_ir_builder.cpp => ir_builder.cpp} | 220 +- torch/csrc/jit/codegen/cuda/ir_builder.h | 127 + torch/csrc/jit/codegen/cuda/ir_cloner.cpp | 49 +- torch/csrc/jit/codegen/cuda/ir_cloner.h | 21 +- torch/csrc/jit/codegen/cuda/ir_container.cpp | 279 ++ torch/csrc/jit/codegen/cuda/ir_container.h | 174 + torch/csrc/jit/codegen/cuda/ir_graphviz.cpp | 5 +- torch/csrc/jit/codegen/cuda/ir_graphviz.h | 2 +- .../jit/codegen/cuda/ir_interface_nodes.h | 86 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 124 +- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 399 +- torch/csrc/jit/codegen/cuda/ir_iostream.h | 83 +- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 355 +- torch/csrc/jit/codegen/cuda/ir_printer.h | 2 +- torch/csrc/jit/codegen/cuda/ir_utils.cpp | 88 +- torch/csrc/jit/codegen/cuda/ir_utils.h | 9 +- torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 136 +- torch/csrc/jit/codegen/cuda/iter_visitor.h | 55 +- torch/csrc/jit/codegen/cuda/kernel.cpp | 199 +- torch/csrc/jit/codegen/cuda/kernel.h | 120 +- torch/csrc/jit/codegen/cuda/kernel_cache.cpp | 47 +- torch/csrc/jit/codegen/cuda/kernel_cache.h | 2 +- .../codegen/cuda/kernel_expr_evaluator.cpp | 44 +- .../jit/codegen/cuda/kernel_expr_evaluator.h | 15 +- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 531 +-- torch/csrc/jit/codegen/cuda/kernel_ir.h | 1196 +---- .../csrc/jit/codegen/cuda/kernel_ir_builder.h | 131 - .../jit/codegen/cuda/kernel_ir_dispatch.cpp | 180 + .../jit/codegen/cuda/kernel_ir_dispatch.h | 118 + .../jit/codegen/cuda/kernel_ir_printer.cpp | 451 -- .../csrc/jit/codegen/cuda/kernel_ir_printer.h | 129 - torch/csrc/jit/codegen/cuda/lower2device.cpp | 537 +-- torch/csrc/jit/codegen/cuda/lower2device.h | 45 +- .../jit/codegen/cuda/lower_alias_memory.cpp | 119 +- .../jit/codegen/cuda/lower_alias_memory.h | 5 +- .../jit/codegen/cuda/lower_allocation.cpp | 432 +- .../csrc/jit/codegen/cuda/lower_allocation.h | 7 +- .../jit/codegen/cuda/lower_double_buffer.cpp | 508 ++ .../jit/codegen/cuda/lower_double_buffer.h | 142 + .../csrc/jit/codegen/cuda/lower_expr_sort.cpp | 4 +- .../codegen/cuda/lower_fusion_simplifier.cpp | 119 + .../codegen/cuda/lower_fusion_simplifier.h | 26 + torch/csrc/jit/codegen/cuda/lower_index.cpp | 215 +- torch/csrc/jit/codegen/cuda/lower_index.h | 49 +- .../jit/codegen/cuda/lower_insert_syncs.cpp | 619 ++- .../jit/codegen/cuda/lower_insert_syncs.h | 40 +- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 62 +- torch/csrc/jit/codegen/cuda/lower_loops.h | 11 +- .../jit/codegen/cuda/lower_magic_zero.cpp | 113 +- .../csrc/jit/codegen/cuda/lower_magic_zero.h | 6 +- .../cuda/lower_misaligned_vectorization.cpp | 330 +- .../cuda/lower_misaligned_vectorization.h | 7 +- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 143 +- torch/csrc/jit/codegen/cuda/lower_predicate.h | 16 +- .../jit/codegen/cuda/lower_replace_size.cpp | 288 ++ .../jit/codegen/cuda/lower_replace_size.h | 25 + torch/csrc/jit/codegen/cuda/lower_shift.cpp | 270 +- torch/csrc/jit/codegen/cuda/lower_shift.h | 36 +- .../codegen/cuda/lower_thread_predicate.cpp | 48 +- .../jit/codegen/cuda/lower_thread_predicate.h | 6 +- .../codegen/cuda/lower_trivial_broadcast.cpp | 119 + .../codegen/cuda/lower_trivial_broadcast.h | 51 + .../codegen/cuda/lower_trivial_reductions.cpp | 25 +- .../codegen/cuda/lower_trivial_reductions.h | 16 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 101 +- torch/csrc/jit/codegen/cuda/lower_unroll.h | 24 +- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 382 +- torch/csrc/jit/codegen/cuda/lower_utils.h | 137 +- .../jit/codegen/cuda/lower_validation.cpp | 23 +- .../csrc/jit/codegen/cuda/lower_validation.h | 2 +- .../jit/codegen/cuda/lower_warp_reduce.cpp | 184 +- .../csrc/jit/codegen/cuda/lower_warp_reduce.h | 2 +- torch/csrc/jit/codegen/cuda/manager.cpp | 20 + torch/csrc/jit/codegen/cuda/manager.h | 2 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 416 +- torch/csrc/jit/codegen/cuda/mutator.h | 2 +- .../jit/codegen/cuda/non_divisible_split.h | 2 +- torch/csrc/jit/codegen/cuda/ops/alias.cpp | 115 + torch/csrc/jit/codegen/cuda/ops/alias.h | 38 + torch/csrc/jit/codegen/cuda/ops/all_ops.h | 1 + torch/csrc/jit/codegen/cuda/ops/composite.cpp | 114 +- torch/csrc/jit/codegen/cuda/ops/composite.h | 8 +- .../jit/codegen/cuda/ops/normalization.cpp | 67 +- .../csrc/jit/codegen/cuda/ops/normalization.h | 2 +- .../codegen/cuda/parallel_dimension_map.cpp | 51 +- .../jit/codegen/cuda/parallel_dimension_map.h | 6 +- .../jit/codegen/cuda/parallel_type_bitmap.h | 2 +- torch/csrc/jit/codegen/cuda/parser.cpp | 298 +- torch/csrc/jit/codegen/cuda/parser.h | 2 +- .../jit/codegen/cuda/partial_split_map.cpp | 28 +- .../csrc/jit/codegen/cuda/partial_split_map.h | 6 +- torch/csrc/jit/codegen/cuda/partition.cpp | 69 +- torch/csrc/jit/codegen/cuda/partition.h | 2 +- .../jit/codegen/cuda/predicate_compute.cpp | 258 +- .../csrc/jit/codegen/cuda/predicate_compute.h | 39 +- .../csrc/jit/codegen/cuda/reference_tensor.h | 2 +- .../csrc/jit/codegen/cuda/root_domain_map.cpp | 64 +- torch/csrc/jit/codegen/cuda/root_domain_map.h | 19 +- .../codegen/cuda/runtime/block_sync_atomic.cu | 6 +- .../codegen/cuda/runtime/grid_reduction.cu | 2 +- .../jit/codegen/cuda/runtime/grid_sync.cu | 6 +- .../csrc/jit/codegen/cuda/runtime/helpers.cu | 16 + torch/csrc/jit/codegen/cuda/runtime/tensor.cu | 10 + .../csrc/jit/codegen/cuda/runtime/welford.cu | 16 +- .../codegen/cuda/scheduler/normalization.cpp | 52 +- .../jit/codegen/cuda/scheduler/pointwise.cpp | 16 +- .../jit/codegen/cuda/scheduler/pointwise.h | 5 + .../jit/codegen/cuda/scheduler/reduction.cpp | 121 +- .../cuda/scheduler/reduction_heuristic.h | 34 +- .../cuda/scheduler/reduction_utils.cpp | 393 +- .../jit/codegen/cuda/scheduler/registry.cpp | 119 +- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 38 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 194 +- .../csrc/jit/codegen/cuda/transform_iter.cpp | 8 +- torch/csrc/jit/codegen/cuda/transform_iter.h | 2 +- .../jit/codegen/cuda/transform_replay.cpp | 38 +- .../csrc/jit/codegen/cuda/transform_replay.h | 2 +- .../jit/codegen/cuda/transform_rfactor.cpp | 38 +- .../csrc/jit/codegen/cuda/transform_rfactor.h | 2 +- .../csrc/jit/codegen/cuda/transform_view.cpp | 164 +- torch/csrc/jit/codegen/cuda/transform_view.h | 12 +- torch/csrc/jit/codegen/cuda/type.cpp | 37 +- torch/csrc/jit/codegen/cuda/type.h | 16 +- .../csrc/jit/codegen/cuda/type_inference.cpp | 16 +- .../csrc/jit/codegen/cuda/type_promotion.cpp | 10 +- torch/csrc/jit/codegen/cuda/utils.cpp | 21 + torch/csrc/jit/codegen/cuda/utils.h | 8 +- torch/csrc/jit/ir/alias_analysis.h | 6 +- torch/jit/_script.py | 4 + 180 files changed, 15144 insertions(+), 11754 deletions(-) rename torch/csrc/jit/codegen/cuda/{kernel_ir_builder.cpp => ir_builder.cpp} (50%) create mode 100644 torch/csrc/jit/codegen/cuda/ir_builder.h create mode 100644 torch/csrc/jit/codegen/cuda/ir_container.cpp create mode 100644 torch/csrc/jit/codegen/cuda/ir_container.h delete mode 100644 torch/csrc/jit/codegen/cuda/kernel_ir_builder.h create mode 100644 torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp create mode 100644 torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h delete mode 100644 torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp delete mode 100644 torch/csrc/jit/codegen/cuda/kernel_ir_printer.h create mode 100644 torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp create mode 100644 torch/csrc/jit/codegen/cuda/lower_double_buffer.h create mode 100644 torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp create mode 100644 torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h create mode 100644 torch/csrc/jit/codegen/cuda/lower_replace_size.cpp create mode 100644 torch/csrc/jit/codegen/cuda/lower_replace_size.h create mode 100644 torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp create mode 100644 torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h create mode 100644 torch/csrc/jit/codegen/cuda/ops/alias.cpp create mode 100644 torch/csrc/jit/codegen/cuda/ops/alias.h diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 36fb0f91e4c..b2d6a43731f 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -45,6 +45,10 @@ namespace c10 { _(prim, CudaFusionGuard) \ _(prim, FunctionalGraph) \ _(prim, add_optional) \ + _(prim, view_copy) \ + _(prim, reshape_copy) \ + _(prim, squeeze_copy) \ + _(prim, unsqueeze_copy) \ _(prim, DifferentiableGraph) \ _(prim, TensorExprGroup) \ _(prim, TensorExprDynamicGroup) \ diff --git a/benchmarks/cpp/nvfuser/batch_norm.cpp b/benchmarks/cpp/nvfuser/batch_norm.cpp index ef6bdd667d6..57e889b19fb 100644 --- a/benchmarks/cpp/nvfuser/batch_norm.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -44,8 +45,8 @@ static void setupBatchNorm(Fusion* fusion, DataType dtype) { bias = castOp(DataType::Float, bias); } - auto momentum_ptr = new Double(kMomentum); - auto eps_ptr = new Double(kEps); + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); auto result = batch_norm( input, diff --git a/benchmarks/cpp/nvfuser/batch_norm_backward.cpp b/benchmarks/cpp/nvfuser/batch_norm_backward.cpp index e4a9fdcb034..77a09564de5 100644 --- a/benchmarks/cpp/nvfuser/batch_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm_backward.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -49,7 +50,7 @@ static void setupBatchNorm_BWD(Fusion* fusion, DataType dtype) { grad_output = castOp(DataType::Float, grad_output); } - auto eps_ptr = new Double(kEps); + auto eps_ptr = IrBuilder::create(kEps); auto result = batch_norm_backward( input, diff --git a/benchmarks/cpp/nvfuser/bert.cpp b/benchmarks/cpp/nvfuser/bert.cpp index f8a389331ee..a1dd58d5646 100644 --- a/benchmarks/cpp/nvfuser/bert.cpp +++ b/benchmarks/cpp/nvfuser/bert.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -36,7 +37,7 @@ static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) { fusion->addInput(tv1); // TODO: should be input - auto d16 = new Double(1.0); + auto d16 = IrBuilder::create(1.0); if (is_fp16) { tv0 = castOp(DataType::Float, tv0); @@ -47,7 +48,7 @@ static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) { auto tv3 = add(tv2, tv0); auto tv10 = softmax(tv3, 3); - auto dropout_tvs = dropout(tv10, new Double(0.9)); + auto dropout_tvs = dropout(tv10, IrBuilder::create(0.9)); auto tv12 = dropout_tvs.mask; auto tv14 = dropout_tvs.output; @@ -83,9 +84,9 @@ static void setupDivMaxSoftmaxDropoutBackward(Fusion* fusion, DataType dtype) { } // TODO: should be inputs - auto d32 = new Double(1.0); + auto d32 = IrBuilder::create(1.0); // fusion->addInput(d32); - auto d33 = new Double(2.0); + auto d33 = IrBuilder::create(2.0); // fusion->addInput(d33); auto tv4 = mul(tv2, tv3); @@ -252,14 +253,15 @@ static void setupBiasDropoutAddLayernormFwd(Fusion* fusion, DataType dtype) { auto tv5 = broadcast(tv4, {true, true, false}); auto tv6 = add(tv3, tv5); - auto dropout_outs = dropout(tv6, new Double(0.9)); + auto dropout_outs = dropout(tv6, IrBuilder::create(0.9)); auto tv8 = dropout_outs.output; auto tv10 = dropout_outs.mask; auto tv11 = add(tv10, tv2); - auto layer_norm_outs = layer_norm(tv11, 1, tv0, tv1, new Double(1e-5)); + auto layer_norm_outs = + layer_norm(tv11, 1, tv0, tv1, IrBuilder::create(1e-5)); auto tv14 = layer_norm_outs.output; auto tv21 = layer_norm_outs.mean; auto tv26 = layer_norm_outs.invstd; @@ -481,7 +483,7 @@ static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) { tv1 = castOp(DataType::Float, tv1); tv8 = castOp(DataType::Float, tv8); } - auto d36 = mul(new Double(1.0), tv1->axis(2)->extent()); + auto d36 = mul(IrBuilder::create(1.0), tv1->axis(2)->extent()); auto d47 = unaryOp(UnaryOpType::Reciprocal, d36); auto tv9 = broadcast(tv5, {true, true, false}); @@ -583,7 +585,7 @@ static void setupBiasDropoutAddLayernormBwd3(Fusion* fusion, DataType dtype) { } // Uncertain this is the right value, but going for it anyways - auto d34 = div(new Double(1.0), tv0->axis(2)->extent()); + auto d34 = div(IrBuilder::create(1.0), tv0->axis(2)->extent()); auto tv25 = mul(tv21, tv0); auto tv26 = mul(tv25, d34); diff --git a/benchmarks/cpp/nvfuser/gelu_backward.cpp b/benchmarks/cpp/nvfuser/gelu_backward.cpp index 9d53d9c2759..f1811795462 100644 --- a/benchmarks/cpp/nvfuser/gelu_backward.cpp +++ b/benchmarks/cpp/nvfuser/gelu_backward.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -41,23 +42,23 @@ static void setupFusion(Fusion* fusion) { auto t5 = castOp(DataType::Float, t4); auto t6 = broadcast(t3, {true, true, false}); auto t7 = add(t6, t5); - auto t8 = mul(t7, new Double(k_079)); - auto t9 = mul(t7, new Double(k_004)); + auto t8 = mul(t7, IrBuilder::create(k_079)); + auto t9 = mul(t7, IrBuilder::create(k_004)); auto t10 = mul(t9, t7); - auto t11 = add(t10, new Int(1)); + auto t11 = add(t10, IrBuilder::create(1)); auto t12 = mul(t8, t11); auto t13 = unaryOp(UnaryOpType::Tanh, t12); - auto t14 = mul(t7, new Double(0.5)); + auto t14 = mul(t7, IrBuilder::create(0.5)); auto t15 = mul(t13, t13); auto t16 = unaryOp(UnaryOpType::Neg, t15); - auto t17 = add(t16, new Int(1)); - auto t18 = mul(t7, new Double(k_010)); + auto t17 = add(t16, IrBuilder::create(1)); + auto t18 = mul(t7, IrBuilder::create(k_010)); auto t19 = mul(t18, t7); - auto t20 = add(t19, new Double(k_079)); + auto t20 = add(t19, IrBuilder::create(k_079)); auto t21 = mul(t17, t20); auto t22 = mul(t14, t21); - auto t23 = add(t13, new Int(1)); - auto t24 = mul(t23, new Double(0.5)); + auto t23 = add(t13, IrBuilder::create(1)); + auto t24 = mul(t23, IrBuilder::create(0.5)); auto t25 = add(t22, t24); auto t26 = mul(t25, t1); diff --git a/benchmarks/cpp/nvfuser/heuristic_cache.cpp b/benchmarks/cpp/nvfuser/heuristic_cache.cpp index 22b8ec4ce97..65f850a016c 100644 --- a/benchmarks/cpp/nvfuser/heuristic_cache.cpp +++ b/benchmarks/cpp/nvfuser/heuristic_cache.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -129,7 +130,7 @@ static auto getLayerForwardNormRuntime( Fusion& fusion = *fusion_ptr.get(); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); auto input = makeSymbolicTensor(shape.size()); fusion.addInput(input); diff --git a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp index 22b8ec4ce97..65f850a016c 100644 --- a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp +++ b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -129,7 +130,7 @@ static auto getLayerForwardNormRuntime( Fusion& fusion = *fusion_ptr.get(); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); auto input = makeSymbolicTensor(shape.size()); fusion.addInput(input); diff --git a/benchmarks/cpp/nvfuser/instance_norm.cpp b/benchmarks/cpp/nvfuser/instance_norm.cpp index 395ac6c8c9c..007291d75f5 100644 --- a/benchmarks/cpp/nvfuser/instance_norm.cpp +++ b/benchmarks/cpp/nvfuser/instance_norm.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -39,8 +40,8 @@ static void setupInstanceNorm(Fusion* fusion, DataType dtype) { const bool kTraining = true; const float kMomentum = 0.1; const float kEps = 1e-5; - auto momentum_ptr = new Double(kMomentum); - auto eps_ptr = new Double(kEps); + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); auto norm = instance_norm( input, diff --git a/benchmarks/cpp/nvfuser/layer_norm.cpp b/benchmarks/cpp/nvfuser/layer_norm.cpp index c4f79b2b668..7500ac8525b 100644 --- a/benchmarks/cpp/nvfuser/layer_norm.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -24,7 +25,7 @@ static void setupLayerNorm(Fusion* fusion, DataType dtype) { const int kReductionAxis = 1; const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); // setup fusion auto input = makeContigTensor(2, dtype); diff --git a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp index 43eafcc42fb..045465e7125 100644 --- a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -22,7 +23,7 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); const int kReductionAxis = 1; - Double* eps_ptr = new Double(1e-5); + Double* eps_ptr = IrBuilder::create(1e-5); // setup fusion auto grad_out = makeContigTensor(2, dtype); diff --git a/benchmarks/cpp/nvfuser/shape_inference.cpp b/benchmarks/cpp/nvfuser/shape_inference.cpp index 33a9404b073..15acc51bb37 100644 --- a/benchmarks/cpp/nvfuser/shape_inference.cpp +++ b/benchmarks/cpp/nvfuser/shape_inference.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -151,7 +152,7 @@ static auto getLayerForwardNormRuntime( Fusion& fusion = *fusion_ptr.get(); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); auto input = makeSymbolicTensor(shape.size()); fusion.addInput(input); diff --git a/benchmarks/cpp/nvfuser/softmax_dropout.cpp b/benchmarks/cpp/nvfuser/softmax_dropout.cpp index b4890eaf8d8..828940933f4 100644 --- a/benchmarks/cpp/nvfuser/softmax_dropout.cpp +++ b/benchmarks/cpp/nvfuser/softmax_dropout.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -35,7 +36,7 @@ static void setupSoftmaxDropout( auto attention_scores = makeContigTensor(4, dtype); auto attention_mask = makeContigTensor(4, dtype); - Double* divisor = new Double(); + Double* divisor = IrBuilder::create(); fusion->addInput(attention_scores); fusion->addInput(attention_mask); @@ -49,8 +50,8 @@ static void setupSoftmaxDropout( attention_scores = div(attention_scores, divisor); attention_scores = add(attention_scores, attention_mask); auto attention_probs = softmax(attention_scores, kReductionAxis); - auto prob = new Double(kDropoutProbability); - auto scale = new Double(kScale); + auto prob = IrBuilder::create(kDropoutProbability); + auto scale = IrBuilder::create(kScale); auto dropout_results = dropout(attention_probs, prob, scale); auto output = dropout_results.output; diff --git a/benchmarks/cpp/nvfuser/utils.cpp b/benchmarks/cpp/nvfuser/utils.cpp index 053fc693908..daf2b21a053 100644 --- a/benchmarks/cpp/nvfuser/utils.cpp +++ b/benchmarks/cpp/nvfuser/utils.cpp @@ -16,8 +16,8 @@ std::string toString(ReductionParams rparams) { if (rparams.schedule_3D) { ss << "3D Schedule // " << "Outer Reduction: " - << (rparams.cross_block_outer_reduce ? "cross block / " : "") - << (rparams.cross_grid_outer_reduce ? "cross grid / " : "") + << (rparams.cross_block_outer_reduction ? "cross block / " : "") + << (rparams.cross_grid_outer_reduction ? "cross grid / " : "") << (rparams.split_grid_dim_outer_reduction ? "split grid dim / " : ""); if (rparams.batches_per_block_outer_reduction > 1 || rparams.persistent_kernel) { @@ -38,9 +38,9 @@ std::string toString(ReductionParams rparams) { } ss << " // Inner Reduction Domain: " - << (rparams.cross_block_inner_reduce ? "cross block reduction / " : "") + << (rparams.cross_block_inner_reduction ? "cross block reduction / " : "") << (rparams.pad_inner_reduction_to_warp ? "pad to warp / " : "") - << (rparams.cross_grid_inner_reduce ? "cross grid reduction / " : ""); + << (rparams.cross_grid_inner_reduction ? "cross grid reduction / " : ""); if (rparams.batches_per_block_inner_reduction > 1 || rparams.persistent_kernel) { @@ -48,7 +48,7 @@ std::string toString(ReductionParams rparams) { << " / "; } - ss << (rparams.cross_grid_inner_reduce && + ss << (rparams.cross_grid_inner_reduction && rparams.split_grid_dim_inner_reduction ? "split grid dimension / " : "") diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index f229ac2679e..b7a4489abe9 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -18,7 +19,7 @@ #include #include #include -#include +#include #include #include #include @@ -86,19 +87,95 @@ void checkIntValue( void checkIntValue( kir::ExpressionEvaluator& evaluator, - const kir::Val* val, - kir::Int::ScalarType expected_value) { + const Val* val, + Int::ScalarType expected_value) { const auto actual_value = evaluator.evaluate(val); TORCH_CHECK(actual_value.has_value()); TORCH_CHECK(actual_value.value() == expected_value); } -bool isPredicated(TensorView* tv, GpuLower& gpulw) { - auto parent_scope = gpulw.lowerValue(tv)->definition()->parentScope(); - if (parent_scope->isA()) { - return !parent_scope->predicate()->value()->isConst(); +TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) { + auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as()); + TensorView* matching_tv = nullptr; + for (auto lowered_tv : used_tvs) { + if (lowered_tv->name() == tv->name()) { + matching_tv = lowered_tv; + } + } + TORCH_INTERNAL_ASSERT(matching_tv != nullptr); + return matching_tv; +} + +class PredicatedChecker : public kir::IrVisitor { + public: + // Checks if the provided tv is written to within a non-trivial conditional + static bool isPredicated(TensorView* tv, GpuLower& gpulw) { + PredicatedChecker checker( + loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs()); + return checker.is_predicated_; + } + + private: + PredicatedChecker() = delete; + + PredicatedChecker(TensorView* tv, std::vector exprs) : tv_(tv) { + kir::IrVisitor::handle(exprs); + } + + using kir::IrVisitor::handle; + bool is_predicated_ = false; + bool predicated_ite_ = false; + TensorView* tv_ = nullptr; + + void handle(kir::IfThenElse* ite) final { + auto prev_ite = predicated_ite_; + predicated_ite_ = !ite->predicate()->value()->isConstScalar(); + kir::IrVisitor::handle(ite); + predicated_ite_ = prev_ite; + } + + void handle(Expr* expr) final { + if (expr->outputs().size() && expr->outputs()[0]->isA()) { + auto ti = expr->outputs()[0]->as(); + if (ti->view() == tv_) { + is_predicated_ = is_predicated_ | predicated_ite_; + } + } + kir::IrVisitor::handle(expr); + } +}; + +class UnswitchInElseChecker : public kir::IrVisitor { + public: + // Checks if there are any unswitched for loops within an else clause + static bool check(GpuLower& gpulw) { + UnswitchInElseChecker checker(gpulw.kernel()->topLevelExprs()); + return checker.found_in_else_; + } + + private: + UnswitchInElseChecker() = delete; + UnswitchInElseChecker(std::vector exprs) { + kir::IrVisitor::handle(exprs); + } + + using kir::IrVisitor::handle; + bool within_else_ = false; + bool found_in_else_ = false; + + void handle(kir::IfThenElse* ite) final { + auto prev_within_else = within_else_; + within_else_ = true; + kir::IrVisitor::handle(ite->elseBody().exprs()); + within_else_ = prev_within_else; + } + + void handle(kir::ForLoop* for_loop) final { + if (for_loop->iter_domain()->getParallelType() == ParallelType::Unswitch) { + found_in_else_ = found_in_else_ || within_else_; + } + kir::IrVisitor::handle(for_loop); } - return true; }; } // namespace @@ -110,7 +187,7 @@ bool isPredicated(TensorView* tv, GpuLower& gpulw) { // (These tests exercise IrGraphGenerator through a non-trivial IR, // to make sure that it runs w/o crashing. The actual output is not // validated) -TEST(NVFuserTest, IrGraphGenerator_CUDA) { +TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -123,10 +200,12 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv2 = add(tv0, new Double(3.141)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.141)); TensorView* tv3 = broadcast(tv0, {false, true, false, true}); - TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv3); - TensorView* tv5 = clamp(tv4, new Double(0.f), new Double(1.f)); + TensorView* tv4 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv3); + TensorView* tv5 = clamp( + tv4, IrBuilder::create(0.f), IrBuilder::create(1.f)); TensorView* tv6 = add(tv2, tv2); // Another checkpoint before adding outputs @@ -149,7 +228,7 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) { .empty()); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(-1)->parallelize(ParallelType::TIDx); @@ -162,11 +241,11 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) { .empty()); } -TEST(NVFuserTest, FusionDispatch_CUDA) { +TEST_F(NVFuserTest, FusionDispatch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* f = new Double{2.f}; + Double* f = IrBuilder::create(2.f); std::stringstream ss1, ss2, ss3; ss1 << f; ss2 << static_cast(f); @@ -177,14 +256,14 @@ TEST(NVFuserTest, FusionDispatch_CUDA) { } // Evaluate basic scalar operations with constant values -TEST(NVFuserTest, FusionExprEvalConstants_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalConstants_CUDA) { Fusion fusion; FusionGuard fg(&fusion); ExpressionEvaluator evaluator(&fusion); - auto* a = new Int(7); - auto* b = new Int(3); + auto* a = IrBuilder::create(7); + auto* b = IrBuilder::create(3); // Avoid div operation because it casts int operands to float checkIntValue(evaluator, neg(a), -7); @@ -195,17 +274,17 @@ TEST(NVFuserTest, FusionExprEvalConstants_CUDA) { } // Evaluate basic scalar operations with bound values -TEST(NVFuserTest, FusionExprEvalBindings_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalBindings_CUDA) { Fusion fusion; FusionGuard fg(&fusion); ExpressionEvaluator evaluator(&fusion); - auto* a = new Int(); - auto* b = new Int(); + auto* a = IrBuilder::create(); + auto* b = IrBuilder::create(); auto* c = add(a, b); auto* d = neg(ceilDiv(c, b)); - auto* e = new Int(0); + auto* e = IrBuilder::create(0); // trying to evaluate before binding should give empty results TORCH_CHECK(!evaluator.evaluate(a).has_value()); @@ -240,7 +319,7 @@ TEST(NVFuserTest, FusionExprEvalBindings_CUDA) { } // Evaluate expressions in a simple IR -TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -251,7 +330,7 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); @@ -296,16 +375,16 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) { } // Evaluate expressions in a more complex IR -TEST(NVFuserTest, FusionExprEvalComplex_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalComplex_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); TensorView* tv6 = add(tv0, tv3); @@ -348,7 +427,7 @@ TEST(NVFuserTest, FusionExprEvalComplex_CUDA) { } // Evaluate expressions post lowering -TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { +TEST_F(NVFuserTest, FusionExprEvalPostLower_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -359,7 +438,7 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); @@ -375,8 +454,8 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); - auto* bid_x = add(tv3->axis(0)->extent(), new Int(0)); - auto* tid_x = add(tv3->axis(-1)->extent(), new Int(0)); + auto* bid_x = add(tv3->axis(0)->extent(), IrBuilder::create(0)); + auto* tid_x = add(tv3->axis(-1)->extent(), IrBuilder::create(0)); // Lower GpuLower gpulw(&fusion); @@ -406,37 +485,39 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) { } // Kernel IR: Evaluate basic scalar operations with constant values -TEST(NVFuserTest, FusionKernelExprEvalConstants_CUDA) { - kir::Kernel kernel; - kir::IrBuilder ir_builder(&kernel); +TEST_F(NVFuserTest, FusionKernelExprEvalConstants_CUDA) { + Fusion fusion; + kir::Kernel kernel(&fusion); + FusionGuard fg((&kernel)->as()); - auto a = ir_builder.create(7); - auto b = ir_builder.create(3); - auto c = ir_builder.subExpr(a, b); - auto d = ir_builder.divExpr(a, b); - auto e = ir_builder.mulExpr(c, d); + auto a = IrBuilder::create(7); + auto b = IrBuilder::create(3); + auto c = IrBuilder::subExpr(a, b); + auto d = IrBuilder::divExpr(a, b); + auto e = IrBuilder::mulExpr(c, d); kir::ExpressionEvaluator evaluator; - checkIntValue(evaluator, ir_builder.negExpr(a), -7); - checkIntValue(evaluator, ir_builder.addExpr(a, b), 10); - checkIntValue(evaluator, ir_builder.negExpr(e), -8); - checkIntValue(evaluator, ir_builder.modExpr(a, b), 1); - checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3); + checkIntValue(evaluator, IrBuilder::negExpr(a), -7); + checkIntValue(evaluator, IrBuilder::addExpr(a, b), 10); + checkIntValue(evaluator, IrBuilder::negExpr(e), -8); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); } // Kernel IR: Evaluate basic scalar operations with bound values -TEST(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { - kir::Kernel kernel; - kir::IrBuilder ir_builder(&kernel); +TEST_F(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { + Fusion fusion; + kir::Kernel kernel(&fusion); + FusionGuard fg((&kernel)->as()); kir::ExpressionEvaluator evaluator; - auto a = ir_builder.create(c10::nullopt); - auto b = ir_builder.create(c10::nullopt); - auto c = ir_builder.addExpr(a, b); - auto d = ir_builder.negExpr(ir_builder.ceilDivExpr(c, b)); - auto e = ir_builder.create(0); + auto a = IrBuilder::create(c10::nullopt); + auto b = IrBuilder::create(c10::nullopt); + auto c = IrBuilder::addExpr(a, b); + auto d = IrBuilder::negExpr(IrBuilder::ceilDivExpr(c, b)); + auto e = IrBuilder::create(0); // trying to evaluate before binding should give empty results TORCH_CHECK(!evaluator.evaluate(a).has_value()); @@ -452,9 +533,9 @@ TEST(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { ASSERT_ANY_THROW(evaluator.bind(e, 100)); checkIntValue(evaluator, c, 10); - checkIntValue(evaluator, ir_builder.subExpr(a, b), 4); - checkIntValue(evaluator, ir_builder.modExpr(a, b), 1); - checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3); + checkIntValue(evaluator, IrBuilder::subExpr(a, b), 4); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); checkIntValue(evaluator, d, -4); // Reset the evaluation context @@ -464,13 +545,13 @@ TEST(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { evaluator.bind(b, 5); checkIntValue(evaluator, c, 7); - checkIntValue(evaluator, ir_builder.subExpr(a, b), -3); - checkIntValue(evaluator, ir_builder.modExpr(a, b), 2); - checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::subExpr(a, b), -3); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 2); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 1); checkIntValue(evaluator, d, -2); } -TEST(NVFuserTest, FusionClear_CUDA) { +TEST_F(NVFuserTest, FusionClear_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -483,7 +564,7 @@ TEST(NVFuserTest, FusionClear_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); @@ -507,14 +588,14 @@ TEST(NVFuserTest, FusionClear_CUDA) { TORCH_CHECK(fusion.inputs().empty()); TORCH_CHECK(fusion.outputs().empty()); - TORCH_CHECK(!fusion.hasReduction()); + TORCH_CHECK(ir_utils::getReductionOps(&fusion).empty()); // 3. Rebuild the IR { TensorView* tv0 = makeSymbolicTensor(3); TensorView* tv1 = makeSymbolicTensor(3); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addInput(tv0); @@ -539,7 +620,7 @@ TEST(NVFuserTest, FusionClear_CUDA) { at::Tensor input2 = at::randn_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; @@ -548,7 +629,7 @@ TEST(NVFuserTest, FusionClear_CUDA) { TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionCopy_CUDA) { +TEST_F(NVFuserTest, FusionCopy_CUDA) { Fusion original_fusion; // Create the test IR @@ -557,7 +638,7 @@ TEST(NVFuserTest, FusionCopy_CUDA) { auto tv0 = makeSymbolicTensor(3); auto tv1 = makeSymbolicTensor(3); - auto tv2 = add(tv1, new Double(2.0)); + auto tv2 = add(tv1, IrBuilder::create(2.0)); auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); original_fusion.addInput(tv0); @@ -622,7 +703,7 @@ TEST(NVFuserTest, FusionCopy_CUDA) { ASSERT_EQ(original_kernel, clone_kernel); } -TEST(NVFuserTest, FusionMove_CUDA) { +TEST_F(NVFuserTest, FusionMove_CUDA) { Fusion fusion; // Create the test IR @@ -631,7 +712,7 @@ TEST(NVFuserTest, FusionMove_CUDA) { auto tv0 = makeSymbolicTensor(3); auto tv1 = makeSymbolicTensor(3); - auto tv2 = add(tv1, new Double(2.0)); + auto tv2 = add(tv1, IrBuilder::create(2.0)); auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); fusion.addInput(tv0); @@ -692,28 +773,28 @@ TEST(NVFuserTest, FusionMove_CUDA) { ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str()); } -TEST(NVFuserTest, FusionSimpleArith_CUDA) { +TEST_F(NVFuserTest, FusionSimpleArith_CUDA) { std::stringstream ss1, ss2; Fusion fusion; FusionGuard fg(&fusion); - Double* d1 = new Double(1.f); - Double* d2 = new Double{2.f}; - Double* d3 = new Double(); + Double* d1 = IrBuilder::create(1.f); + Double* d2 = IrBuilder::create(2.f); + Double* d3 = IrBuilder::create(); // Disrupt the fusion to make sure guard works well { Fusion fusion2; FusionGuard fg(&fusion2); - Double* d1 = new Double(1.f); - Double* d2 = new Double(2.f); + Double* d1 = IrBuilder::create(1.f); + Double* d2 = IrBuilder::create(2.f); add(d1, d2); ss2 << fusion2; } - new BinaryOp(BinaryOpType::Add, d3, d1, d2); + IrBuilder::create(BinaryOpType::Add, d3, d1, d2); ss1 << fusion; TORCH_CHECK( @@ -721,22 +802,22 @@ TEST(NVFuserTest, FusionSimpleArith_CUDA) { "Error where explicit add nodes don't match implicit add nodes."); } -TEST(NVFuserTest, FusionSimpleTypePromote_CUDA) { +TEST_F(NVFuserTest, FusionSimpleTypePromote_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* d4 = new Double{4.f}; - Int* i1 = new Int{3}; + Double* d4 = IrBuilder::create(4.f); + Int* i1 = IrBuilder::create(3); auto d5 = add(d4, i1); TORCH_CHECK(d5->getDataType() == DataType::Double); } -TEST(NVFuserTest, FusionRegister_CUDA) { +TEST_F(NVFuserTest, FusionRegister_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* v1 = new Double{1.f}; - Double* v2 = new Double{2.f}; + Double* v1 = IrBuilder::create(1.f); + Double* v2 = IrBuilder::create(2.f); Val* v3 = binaryOp(BinaryOpType::Add, v1, v2); Val* v4 = binaryOp(BinaryOpType::Add, v1, v2); TORCH_CHECK(v1->name() + 1 == v2->name()); @@ -748,14 +829,18 @@ TEST(NVFuserTest, FusionRegister_CUDA) { // dummy expr with 2 outputs only for toposort test. struct DummyExpr : public Expr { ~DummyExpr() = default; - DummyExpr(Val* _outlhs, Val* _outrhs, Val* _lhs, Val* _rhs) - : Expr(ExprType::UnaryOp) // Not terribly safe... + DummyExpr( + IrBuilderPasskey passkey, + Val* _outlhs, + Val* _outrhs, + Val* _lhs, + Val* _rhs) + : Expr(passkey, ExprType::UnaryOp) // Not terribly safe... { addOutput(_outlhs); addOutput(_outrhs); addInput(_lhs); addInput(_rhs); - this->name_ = FusionGuard::getCurFusion()->registerExpr(this); } DummyExpr(const DummyExpr& other) = delete; DummyExpr& operator=(const DummyExpr& other) = delete; @@ -763,7 +848,7 @@ struct DummyExpr : public Expr { DummyExpr& operator=(DummyExpr&& other) = delete; }; -TEST(NVFuserTest, FusionTopoSort_CUDA) { +TEST_F(NVFuserTest, FusionTopoSort_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -771,23 +856,23 @@ TEST(NVFuserTest, FusionTopoSort_CUDA) { // e1: v4 = add(v3, v2) // e2: v5 = add(v2, v4) // e3: v6 = add(v5, v5) - Double* v0 = new Double{1.f}; - Double* v1 = new Double{2.f}; - Double* v2 = new Double(); - Double* v3 = new Double(); - Double* v4 = new Double(); - Double* v5 = new Double(); - Double* v6 = new Double(); + Double* v0 = IrBuilder::create(1.f); + Double* v1 = IrBuilder::create(2.f); + Double* v2 = IrBuilder::create(); + Double* v3 = IrBuilder::create(); + Double* v4 = IrBuilder::create(); + Double* v5 = IrBuilder::create(); + Double* v6 = IrBuilder::create(); std::vector inputs = {v0, v1}; for (auto val : inputs) { fusion.addInput(val); } - Expr* e0 = new DummyExpr(v3, v2, v1, v0); - Expr* e1 = new BinaryOp(BinaryOpType::Add, v4, v3, v2); - Expr* e2 = new BinaryOp(BinaryOpType::Add, v5, v2, v4); - Expr* e3 = new BinaryOp(BinaryOpType::Add, v6, v5, v5); + Expr* e0 = IrBuilder::create(v3, v2, v1, v0); + Expr* e1 = IrBuilder::create(BinaryOpType::Add, v4, v3, v2); + Expr* e2 = IrBuilder::create(BinaryOpType::Add, v5, v2, v4); + Expr* e3 = IrBuilder::create(BinaryOpType::Add, v6, v5, v5); fusion.addOutput(v2); fusion.addOutput(v3); @@ -824,7 +909,7 @@ TEST(NVFuserTest, FusionTopoSort_CUDA) { TORCH_CHECK(v6->definition()->name() == 3); } -TEST(NVFuserTest, FusionTensor_CUDA) { +TEST_F(NVFuserTest, FusionTensor_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); Fusion fusion; @@ -833,7 +918,7 @@ TEST(NVFuserTest, FusionTensor_CUDA) { { auto tensor = at::randn({2, 3, 4, 5}, options); auto tensor_type = TensorType::create(tensor); - auto fuser_tensor = new TensorView(tensor_type); + auto fuser_tensor = IrBuilder::create(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); @@ -856,7 +941,7 @@ TEST(NVFuserTest, FusionTensor_CUDA) { auto sliced_tensor = tensor.slice(1, 0, -1, 2); auto tensor_type = TensorType::create(sliced_tensor); - auto fuser_tensor = new TensorView(tensor_type); + auto fuser_tensor = IrBuilder::create(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); @@ -873,7 +958,7 @@ TEST(NVFuserTest, FusionTensor_CUDA) { auto tensor = at::randn({2, 3, 4, 5}, options); auto permuted_tensor = tensor.permute({0, 3, 1, 2}); auto tensor_type = TensorType::create(permuted_tensor); - auto fuser_tensor = new TensorView(tensor_type); + auto fuser_tensor = IrBuilder::create(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); @@ -888,15 +973,15 @@ TEST(NVFuserTest, FusionTensor_CUDA) { } } -TEST(NVFuserTest, FusionFilterVals_CUDA) { +TEST_F(NVFuserTest, FusionFilterVals_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); auto tv1 = makeSymbolicTensor(1); - auto scalar0 = new Double(0); - auto scalar1 = new Int(0); - auto scalar2 = new Int(1); + auto scalar0 = IrBuilder::create(0); + auto scalar1 = IrBuilder::create(0); + auto scalar2 = IrBuilder::create(1); const std::vector vals = {tv0, scalar0, tv1, scalar1, scalar2}; @@ -926,7 +1011,7 @@ TEST(NVFuserTest, FusionFilterVals_CUDA) { "Not expecting any results"); } -TEST(NVFuserTest, FusionTVSplit_CUDA) { +TEST_F(NVFuserTest, FusionTVSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -943,7 +1028,7 @@ TEST(NVFuserTest, FusionTVSplit_CUDA) { static_cast(outer)->lhs()->sameAs( tv->getRootDomain()[2]->extent()) && static_cast(static_cast(outer)->rhs()) - ->sameAs(new Int(2))); + ->sameAs(IrBuilder::create(2))); IterDomain* inner = static_cast(tv->axis(3)); TORCH_CHECK( @@ -952,7 +1037,7 @@ TEST(NVFuserTest, FusionTVSplit_CUDA) { static_cast(inner->extent())->value().value() == 2); } -TEST(NVFuserTest, FusionTVMerge_CUDA) { +TEST_F(NVFuserTest, FusionTVMerge_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -970,7 +1055,7 @@ TEST(NVFuserTest, FusionTVMerge_CUDA) { tv->getRootDomain()[2]->extent()); } -TEST(NVFuserTest, FusionTVReorder_CUDA) { +TEST_F(NVFuserTest, FusionTVReorder_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1020,38 +1105,43 @@ TEST(NVFuserTest, FusionTVReorder_CUDA) { TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); } -TEST(NVFuserTest, FusionEquality_CUDA) { +TEST_F(NVFuserTest, FusionEquality_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* fval1 = new Double(); + Double* fval1 = IrBuilder::create(); Double* fval1_copy = fval1; - Double* fval2 = new Double(); - Double* fone = new Double(1.0); + Double* fval2 = IrBuilder::create(); + Double* fone = IrBuilder::create(1.0); TORCH_CHECK(fval1->sameAs(fval1_copy)); TORCH_CHECK(!fval1->sameAs(fval2)); TORCH_CHECK(!fone->sameAs(fval1)); - TORCH_CHECK(fone->sameAs(new Double(1.0))); + TORCH_CHECK(fone->sameAs(IrBuilder::create(1.0))); - Int* ival1 = new Int(); + Int* ival1 = IrBuilder::create(); Int* ival1_copy = ival1; - Int* ival2 = new Int(); - Int* ione = new Int(1); + Int* ival2 = IrBuilder::create(); + Int* ione = IrBuilder::create(1); TORCH_CHECK(ival1->sameAs(ival1_copy)); TORCH_CHECK(!ival1->sameAs(ival2)); TORCH_CHECK(!ione->sameAs(ival1)); - TORCH_CHECK(ione->sameAs(new Int(1))); - - BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1); - BinaryOp* add1_copy = - new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1); - BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Double(), fval1, ival1); - - UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Double(), fval1); - UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Double(), fval2); - UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Double(), fval1); + TORCH_CHECK(ione->sameAs(IrBuilder::create(1))); + + BinaryOp* add1 = IrBuilder::create( + BinaryOpType::Add, IrBuilder::create(), fval1, ival1); + BinaryOp* add1_copy = IrBuilder::create( + BinaryOpType::Add, IrBuilder::create(), fval1, ival1); + BinaryOp* sub1 = IrBuilder::create( + BinaryOpType::Sub, IrBuilder::create(), fval1, ival1); + + UnaryOp* neg1 = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval1); + UnaryOp* neg2 = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval2); + UnaryOp* neg1_copy = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval1); TORCH_CHECK(add1->sameAs(add1_copy)); TORCH_CHECK(!add1->sameAs(sub1)); @@ -1061,22 +1151,22 @@ TEST(NVFuserTest, FusionEquality_CUDA) { TORCH_CHECK(!neg1->sameAs(neg2)); } -TEST(NVFuserTest, FusionDependency_CUDA) { +TEST_F(NVFuserTest, FusionDependency_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Double* d0 = new Double(0.f); - Double* d1 = new Double(1.f); + Double* d0 = IrBuilder::create(0.f); + Double* d1 = IrBuilder::create(1.f); auto d2 = add(d0, d1); auto d3 = add(d2, d2); - Double* d4 = new Double(4.f); - Double* d5 = new Double(5.f); + Double* d4 = IrBuilder::create(4.f); + Double* d5 = IrBuilder::create(5.f); auto d6 = add(d4, d5); - Double* d7 = new Double(7.f); - Double* d8 = new Double(8.f); + Double* d7 = IrBuilder::create(7.f); + Double* d8 = IrBuilder::create(8.f); auto d9 = add(d7, d8); auto d10 = add(d6, d9); @@ -1131,7 +1221,7 @@ TEST(NVFuserTest, FusionDependency_CUDA) { TORCH_CHECK(dep_chain.empty()); } -TEST(NVFuserTest, FusionParser_CUDA) { +TEST_F(NVFuserTest, FusionParser_CUDA) { // This test may not pass if using a custom block sync as there may // be additional calls. Skip the test as it's not specifically // relevant with block synchronizatin. @@ -1174,31 +1264,31 @@ TEST(NVFuserTest, FusionParser_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { if ((((((((((nvfuser_index_t)blockIdx.x) * 1) + 0) * 1) + 0) * 128) + ((nvfuser_index_t)threadIdx.x)) < T0.size[0])) { - constexpr nvfuser_index_t ki183 = 0; + constexpr nvfuser_index_t i33 = 0; float T5[1]; - constexpr nvfuser_index_t ki217 = 0; - T5[ki217] = 0; - constexpr nvfuser_index_t ki208 = 0; - T5[ki208] - = T1[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki183) * 1) + ki208) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; + constexpr nvfuser_index_t i45 = 0; + T5[i45] = 0; + constexpr nvfuser_index_t i41 = 0; + T5[i41] + = T1[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i33) * 1) + i41) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; float T4[1]; - constexpr nvfuser_index_t ki223 = 0; - T4[ki223] = 0; - constexpr nvfuser_index_t ki203 = 0; - T4[ki203] - = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki183) * 1) + ki203) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; + constexpr nvfuser_index_t i47 = 0; + T4[i47] = 0; + constexpr nvfuser_index_t i39 = 0; + T4[i39] + = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i33) * 1) + i39) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; float T6[1]; - constexpr nvfuser_index_t ki192 = 0; + constexpr nvfuser_index_t i37 = 0; float T2[1]; T2[0] - = T4[ki192] - * T5[ki192]; - T6[ki192] + = T4[i37] + * T5[i37]; + T6[i37] = T2[0] - * T4[ki192]; - constexpr nvfuser_index_t ki185 = 0; - T3[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki183) * 1) + ki185) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] - = T6[ki185]; + * T4[i37]; + constexpr nvfuser_index_t i35 = 0; + T3[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i33) * 1) + i35) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] + = T6[i35]; } } )"; @@ -1227,62 +1317,25 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te } FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1, input2}, lparams); auto outputs = fe.runFusion({input1, input2}, lparams); at::Tensor output_ref = input1 * input2 * input1; TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionForLoop_CUDA) { -// TODO(kir): re-enable this test -// due to the current "GpuLower guard" approach, we can only create -// kernel IR during GpuLower::lower() -#if 0 - Fusion fusion; - FusionGuard fg(&fusion); - - const auto TV0 = new TensorView( - new TensorDomain({new IterDomain(new Int(0), new Int(16))}), - DataType::Float); - const auto TV1 = new TensorView( - new TensorDomain({new IterDomain(new Int(0), new Int(16))}), - DataType::Float); - - fusion.addInput(TV0); - fusion.addInput(TV1); - - auto ID0 = new kir::IterDomain(new IterDomain(new Int(0), new Int(8))); - - TensorView* TV2 = add(TV0, TV1); - BinaryOp* op = static_cast(TV2->definition(); - fusion.addOutput(TV2); - - auto fl = new kir::ForLoop(new kir::Int(c10::nullopt), ID0, {op}); - - std::stringstream result; - std::stringstream ref; - result << fl; - ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}"; - - if (result.str().compare(ref.str()) == 0) { - std::stringstream err_msg; - err_msg << "ForLoop printing has changed or something has gone wrong. " - << result.str() << "\n does not match reference: " << ref.str() - << std::endl; - TORCH_CHECK(false, err_msg.str()); - } -#endif -} - -TEST(NVFuserTest, FusionOuterSplit_CUDA) { +TEST_F(NVFuserTest, FusionOuterSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(3); - new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0)); - TensorView* tv1 = add(tv0, new Double(2.0)); - TensorView* tv2 = add(tv1, new Double(3.0)); + IrBuilder::create( + BinaryOpType::Add, + tv0, + IrBuilder::create(0.0), + IrBuilder::create(1.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); fusion.addOutput(tv2); //[I0, I1, I2] @@ -1312,15 +1365,19 @@ TEST(NVFuserTest, FusionOuterSplit_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionCodeGen_CUDA) { +TEST_F(NVFuserTest, FusionCodeGen_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(3); - new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0)); - TensorView* tv1 = add(tv0, new Double(2.0)); - TensorView* tv2 = add(tv1, new Double(3.0)); + IrBuilder::create( + BinaryOpType::Add, + tv0, + IrBuilder::create(0.0), + IrBuilder::create(1.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); fusion.addOutput(tv2); //[I0, I1, I2] @@ -1349,13 +1406,13 @@ TEST(NVFuserTest, FusionCodeGen_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionCodeGen2_CUDA) { +TEST_F(NVFuserTest, FusionCodeGen2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(3); TensorView* tv1 = makeSymbolicTensor(3); - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addInput(tv0); @@ -1382,7 +1439,7 @@ TEST(NVFuserTest, FusionCodeGen2_CUDA) { at::Tensor input2 = at::randn_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; @@ -1391,7 +1448,7 @@ TEST(NVFuserTest, FusionCodeGen2_CUDA) { TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionSimplePWise_CUDA) { +TEST_F(NVFuserTest, FusionSimplePWise_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // dimensionality of the problem @@ -1407,7 +1464,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -1439,7 +1496,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) { at::Tensor output = at::empty_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; @@ -1448,7 +1505,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionExecKernel_CUDA) { +TEST_F(NVFuserTest, FusionExecKernel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1462,7 +1519,7 @@ TEST(NVFuserTest, FusionExecKernel_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -1490,7 +1547,7 @@ TEST(NVFuserTest, FusionExecKernel_CUDA) { at::Tensor input2 = at::ones_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); at::Tensor check = at::full({1, 128}, 4, options); @@ -1502,7 +1559,7 @@ int ceilDiv_(int a, int b) { return (a + b - 1) / b; } -TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -1517,10 +1574,10 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = add(tv1, new Double(3.0)); - TensorView* tv4 = mul(tv1, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); TensorView* tv5 = add(tv3, tv2); TensorView* tv6 = add(tv5, tv4); @@ -1538,7 +1595,8 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { tv0->computeAt(tv7, 1); - GpuLower gpulw(&fusion); + ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP); + loop_map.build(&fusion); // The this-position of the last tensor should be zero. TORCH_CHECK( @@ -1550,11 +1608,12 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { // The position of every other tensor should be 1. for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0))); + + TORCH_CHECK(loop_map.areMapped(tv7->axis(0), tv->axis(0))); } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -1579,14 +1638,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -1600,9 +1659,9 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); @@ -1621,7 +1680,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { tv0->computeAt(tv6, 1); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1643,13 +1702,13 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { std::vector aten_outputs = {t5, t6}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 @@ -1662,7 +1721,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { TensorView* tv1 = makeSymbolicTensor(4); fusion.addInput(tv1); - TensorView* tv2 = mul(tv1, new Double(.979361)); + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); fusion.addOutput(tv3); @@ -1679,7 +1738,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { tv3->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1700,14 +1759,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { at::Tensor cg_output = at::empty_like(t0, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 @@ -1747,7 +1806,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { tv6->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -1769,14 +1828,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 @@ -1788,7 +1847,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -1809,14 +1868,14 @@ TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1824,7 +1883,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -1848,26 +1907,26 @@ TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1.0)); + auto tv1 = add(tv0, IrBuilder::create(1.0)); auto tv2 = makeSymbolicTensor(1); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(3.0)); + auto tv3 = add(tv2, IrBuilder::create(3.0)); auto tv4 = add(tv1, tv3); fusion.addOutput(tv4); @@ -1899,9 +1958,6 @@ TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { auto tv5_domain_current = tv5->domain()->domain(); TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain"); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; @@ -1919,25 +1975,27 @@ TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { std::vector aten_inputs = {t0, t2, t6}; std::vector aten_outputs = {t4, t7}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1.0)); + auto tv1 = add(tv0, IrBuilder::create(1.0)); auto tv2 = makeSymbolicTensor(1); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(3.0)); + auto tv3 = add(tv2, IrBuilder::create(3.0)); auto tv4 = add(tv1, tv3); fusion.addOutput(tv4); @@ -1964,9 +2022,6 @@ TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { tv2->computeAt(tv4, -1); tv0->computeAt(tv7, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; @@ -1984,13 +2039,15 @@ TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { std::vector aten_inputs = {t0, t2, t6}; std::vector aten_outputs = {t4, t7}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -2005,10 +2062,10 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = add(tv1, new Double(3.0)); - TensorView* tv4 = mul(tv1, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); TensorView* tv5 = add(tv3, tv2); TensorView* tv6 = add(tv5, tv4); @@ -2036,14 +2093,17 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 && tv6->getMaxProducerPosition() == 1); + ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP); + loop_map.build(&fusion); + // The position of every other tensor should be 1. for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0))); + TORCH_CHECK(loop_map.areMapped(tv7->axis(0), tv->axis(0))); } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -2068,14 +2128,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -2089,9 +2149,9 @@ TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); @@ -2110,7 +2170,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { tv0->computeWith(tv6, 1); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -2132,13 +2192,13 @@ TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { std::vector aten_outputs = {t5, t6}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 @@ -2151,7 +2211,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { TensorView* tv1 = makeSymbolicTensor(4); fusion.addInput(tv1); - TensorView* tv2 = mul(tv1, new Double(.979361)); + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); fusion.addOutput(tv3); @@ -2173,7 +2233,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { tv3->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -2194,14 +2254,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { at::Tensor cg_output = at::empty_like(t0, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 @@ -2240,7 +2300,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { tv6->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -2262,14 +2322,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 @@ -2281,7 +2341,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -2302,14 +2362,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2317,7 +2377,7 @@ TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { fusion.addInput(tv0); TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -2341,14 +2401,14 @@ TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -2 @@ -2358,9 +2418,9 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv1, new Double(-2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); fusion.addOutput(tv2); fusion.addOutput(tv3); @@ -2387,10 +2447,12 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { TORCH_CHECK( tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1); + ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP); + loop_map.build(&fusion); + // Note that tv2 is also computed at tv3. for (auto tv : {tv1, tv2}) { - TORCH_CHECK( - gpulw.caLoopMap().areMapped(tv->axis(0), computeAtTarget->axis(0))); + TORCH_CHECK(loop_map.areMapped(tv->axis(0), computeAtTarget->axis(0))); } TORCH_CHECK(tv3->getComputeAtPosition() == 0); @@ -2414,7 +2476,7 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -2422,7 +2484,7 @@ TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { } // Similar to ComputeAtMultiConsumers, but with a common consumer. -TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -2 @@ -2434,11 +2496,11 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv1, new Double(-2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); TensorView* tv4 = add(tv2, tv3); - TensorView* tv5 = mul(tv4, new Double(5.0)); + TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); fusion.addOutput(tv3); fusion.addOutput(tv4); fusion.addOutput(tv5); @@ -2492,14 +2554,14 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -1 @@ -2511,10 +2573,10 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv2, new Double(-1.0)); - TensorView* tv4 = add(tv1, new Double(4.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); + TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); TensorView* tv5 = add(tv3, tv4); fusion.addOutput(tv5); @@ -2541,7 +2603,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // All tensors should have the same dimenionality as the target for (Val* val : fusion.vals()) { - if (fusion.hasInput(val) || + if (val->isFusionInput() || val->getValType().value() != ValType::TensorView) { continue; } @@ -2555,7 +2617,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { } for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (!fusion.hasInput(tv)) { + if (!tv->isFusionInput()) { tv->axis(1)->parallelize(ParallelType::Unroll); tv->axis(-1)->parallelize(ParallelType::TIDx); } @@ -2574,7 +2636,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { at::Tensor cg_output = at::empty_like(aten_input, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( @@ -2583,7 +2645,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { // Similar to the above common consumer test but adds an additional // tensor that has no common consumer with the other tensors. -TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv2 * -1 @@ -2596,12 +2658,12 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv2, new Double(-1.0)); - TensorView* tv4 = add(tv1, new Double(4.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); + TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); TensorView* tv5 = add(tv3, tv4); - TensorView* tv6 = add(tv1, new Double(6.0)); + TensorView* tv6 = add(tv1, IrBuilder::create(6.0)); fusion.addOutput(tv5); fusion.addOutput(tv6); @@ -2627,7 +2689,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // All tensors should have the same dimenionality as the target for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (fusion.hasInput(tv)) { + if (tv->isFusionInput()) { continue; } TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); @@ -2640,7 +2702,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = val->as(); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -2664,7 +2726,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -2673,7 +2735,7 @@ TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { // Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor // that does not have data dependency with the consumer. -TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 // tv3 = tv1 * -2 @@ -2686,13 +2748,13 @@ TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = mul(tv1, new Double(-2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); TensorView* tv4 = add(tv2, tv3); - TensorView* tv5 = mul(tv4, new Double(5.0)); + TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); // Notice that tv6 is not a consumer of tv4. - TensorView* tv6 = mul(tv1, new Double(6.0)); + TensorView* tv6 = mul(tv1, IrBuilder::create(6.0)); fusion.addOutput(tv3); fusion.addOutput(tv4); fusion.addOutput(tv5); @@ -2737,7 +2799,7 @@ TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { at::empty_like(aten_input, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -2822,7 +2884,7 @@ void checkIdMapped( } // namespace -TEST(NVFuserTest, FusionRootMappingBasic_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2876,7 +2938,7 @@ TEST(NVFuserTest, FusionRootMappingBasic_CUDA) { checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain()); } -TEST(NVFuserTest, FusionRootMappingRfactor_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2960,7 +3022,7 @@ TEST(NVFuserTest, FusionRootMappingRfactor_CUDA) { {true, true, false}); } -TEST(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2987,7 +3049,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { {true, false}); } -TEST(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3021,7 +3083,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); } -TEST(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3050,7 +3112,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { {true, false}); } -TEST(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3095,13 +3157,13 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { } // Reproducer of issue #749 -TEST(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {false, true}); auto tv4 = add(tv0, tv3); @@ -3153,13 +3215,13 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { } // Similar to RootMappingReductionDependency5 but with rFactor -TEST(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {false, true}); auto tv4 = add(tv0, tv3); @@ -3227,7 +3289,7 @@ TEST(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { {true, true}); } -TEST(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3265,7 +3327,9 @@ TEST(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { {false, false}); } -TEST(NVFuserTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { +TEST_F( + NVFuserTest, + FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3299,7 +3363,7 @@ TEST(NVFuserTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { {false, true}); } -TEST(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3386,7 +3450,7 @@ TEST(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { {true, false}); } -TEST(NVFuserTest, FusionRootMappingBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3426,7 +3490,7 @@ TEST(NVFuserTest, FusionRootMappingBroadcast_CUDA) { } // Reproducer of issue #723 -TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { +TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3461,7 +3525,7 @@ TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto t3 = t0; @@ -3470,13 +3534,13 @@ TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = broadcast(tv1, {true, false}); auto tv3 = broadcast(tv1, {false, true}); auto tv4 = add(tv2, tv3); @@ -3486,7 +3550,7 @@ TEST(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) { ASSERT_ANY_THROW(tv1->computeAt(tv4, 1)); } -TEST(NVFuserTest, FusionScalarInputs_CUDA) { +TEST_F(NVFuserTest, FusionScalarInputs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3495,13 +3559,13 @@ TEST(NVFuserTest, FusionScalarInputs_CUDA) { TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - Double* d0 = new Double(); + Double* d0 = IrBuilder::create(); fusion.addInput(d0); - Double* d1 = new Double(); + Double* d1 = IrBuilder::create(); fusion.addInput(d1); - Double* d2 = new Double(); + Double* d2 = IrBuilder::create(); fusion.addInput(d2); - Double* d3 = new Double(); + Double* d3 = IrBuilder::create(); fusion.addInput(d3); Val* d4 = mul(d0, d1); Val* d5 = sub(d2, d3); @@ -3524,7 +3588,7 @@ TEST(NVFuserTest, FusionScalarInputs_CUDA) { tv4->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -3568,14 +3632,14 @@ TEST(NVFuserTest, FusionScalarInputs_CUDA) { at::Scalar(fl3)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionLoopUnroll_CUDA) { +TEST_F(NVFuserTest, FusionLoopUnroll_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3589,7 +3653,7 @@ TEST(NVFuserTest, FusionLoopUnroll_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -3621,7 +3685,7 @@ TEST(NVFuserTest, FusionLoopUnroll_CUDA) { at::Tensor input1 = at::randn({129, 13, 3}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}); auto outputs = fe.runFusion({input0, input1}); TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0)))); @@ -3636,11 +3700,11 @@ Val* gen_jit_operand(std::pair desc) { return makeSymbolicTensor(2, desc.second); } else if (desc.first == ValType::Scalar) { if (desc.second == DataType::Float) { - return new Double(); + return IrBuilder::create(); } else if (desc.second == DataType::Double) { - return new Double(); + return IrBuilder::create(); } else if (desc.second == DataType::Int) { - return new Int(); + return IrBuilder::create(); } else { TORCH_CHECK(false, "Not currently supported type: ", desc.first); } @@ -3763,7 +3827,7 @@ void test_op( at::manual_seed(0); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs_ivalues); fe.runFusion(aten_inputs_ivalues, output_vect); cudaDeviceSynchronize(); @@ -3809,7 +3873,7 @@ void test_op( std::make_index_sequence{}); } -TEST(NVFuserTest, FusionUnaryOps_CUDA) { +TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { using OpTuple = std::tuple; @@ -3833,7 +3897,6 @@ TEST(NVFuserTest, FusionUnaryOps_CUDA) { OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"}, OpTuple{at::floor, UnaryOpType::Floor, "floor"}, OpTuple{at::frac, UnaryOpType::Frac, "frac"}, - // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"}, OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"}, OpTuple{at::log, UnaryOpType::Log, "log"}, OpTuple{at::log10, UnaryOpType::Log10, "log10"}, @@ -3904,7 +3967,7 @@ TEST(NVFuserTest, FusionUnaryOps_CUDA) { } } -TEST(NVFuserTest, FusionBinaryOps_CUDA) { +TEST_F(NVFuserTest, FusionBinaryOps_CUDA) { using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&); using OpTuple = std::tuple; @@ -4009,7 +4072,7 @@ TEST(NVFuserTest, FusionBinaryOps_CUDA) { } } -TEST(NVFuserTest, FusionTernaryOps_CUDA) { +TEST_F(NVFuserTest, FusionTernaryOps_CUDA) { std::vector dtypes = {DataType::Double, DataType::Float}; for (auto dtype : dtypes) { @@ -4024,9 +4087,15 @@ TEST(NVFuserTest, FusionTernaryOps_CUDA) { /*JIT Func */ [&](Val* in1) -> Val* { if (dtype == DataType::Float) { - return clamp(in1, new Double(0.f), new Double(1.f)); + return clamp( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } else { - return clamp(in1, new Double(0.f), new Double(1.f)); + return clamp( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } }, /*Output */ std::make_pair(ValType::TensorView, dtype), @@ -4043,9 +4112,15 @@ TEST(NVFuserTest, FusionTernaryOps_CUDA) { /*JIT Func */ [&](Val* in1) -> Val* { if (dtype == DataType::Float) { - return threshold(in1, new Double(0.f), new Double(1.f)); + return threshold( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } else { - return threshold(in1, new Double(0.f), new Double(1.f)); + return threshold( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); } }, /*Output */ std::make_pair(ValType::TensorView, dtype), @@ -4070,7 +4145,7 @@ TEST(NVFuserTest, FusionTernaryOps_CUDA) { } } -TEST(NVFuserTest, FusionCompoundOps_CUDA) { +TEST_F(NVFuserTest, FusionCompoundOps_CUDA) { std::vector dtypes = {DataType::Double, DataType::Float}; for (auto dtype : dtypes) { @@ -4114,7 +4189,7 @@ TEST(NVFuserTest, FusionCompoundOps_CUDA) { } } -TEST(NVFuserTest, FusionCastOps_CUDA) { +TEST_F(NVFuserTest, FusionCastOps_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4139,7 +4214,7 @@ TEST(NVFuserTest, FusionCastOps_CUDA) { const at::ArrayRef input_ivalues(inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, input_ivalues); auto outputs = fe.runFusion(input_ivalues); ref_output = at::_cast_Half(at::_cast_Double(input1)); @@ -4156,7 +4231,7 @@ TEST(NVFuserTest, FusionCastOps_CUDA) { // Start off simple, block on the outer dim // block stride + thread all reduce + unrolling on inner dim -TEST(NVFuserTest, FusionReduction1_CUDA) { +TEST_F(NVFuserTest, FusionReduction1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4165,10 +4240,13 @@ TEST(NVFuserTest, FusionReduction1_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, 128); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -4207,7 +4285,7 @@ TEST(NVFuserTest, FusionReduction1_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -4216,7 +4294,7 @@ TEST(NVFuserTest, FusionReduction1_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction2_CUDA) { +TEST_F(NVFuserTest, FusionReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4225,7 +4303,8 @@ TEST(NVFuserTest, FusionReduction2_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4278,14 +4357,14 @@ TEST(NVFuserTest, FusionReduction2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1}); testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction3_CUDA) { +TEST_F(NVFuserTest, FusionReduction3_CUDA) { // What if Z participates in the reduction with X? Fusion fusion; FusionGuard fg(&fusion); @@ -4295,7 +4374,8 @@ TEST(NVFuserTest, FusionReduction3_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4328,7 +4408,7 @@ TEST(NVFuserTest, FusionReduction3_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); auto aten_output = aten_input.to(at::kDouble).sum({1}); @@ -4337,7 +4417,7 @@ TEST(NVFuserTest, FusionReduction3_CUDA) { &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction4_CUDA) { +TEST_F(NVFuserTest, FusionReduction4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4351,7 +4431,8 @@ TEST(NVFuserTest, FusionReduction4_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2); + TensorView* tv3 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); // tv3[I0, R1] = tv2[I0, I1] TensorView* tv4 = makeSymbolicTensor(1); @@ -4393,7 +4474,7 @@ TEST(NVFuserTest, FusionReduction4_CUDA) { at::Tensor t4 = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1, t4}); auto cg_outputs = fe.runFusion({t0, t1, t4}); auto t2 = t0.add(t1); @@ -4404,7 +4485,7 @@ TEST(NVFuserTest, FusionReduction4_CUDA) { &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction5_CUDA) { +TEST_F(NVFuserTest, FusionReduction5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4413,7 +4494,8 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4431,7 +4513,7 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { tv1->axis(0)->parallelize(ParallelType::BIDy); for (auto* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { val->as()->axis(-1)->parallelize(ParallelType::TIDx); } @@ -4446,7 +4528,7 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { at::Tensor cg_output = at::empty({bidy, tidx}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -4454,7 +4536,7 @@ TEST(NVFuserTest, FusionReduction5_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduction6_CUDA) { +TEST_F(NVFuserTest, FusionReduction6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4466,10 +4548,13 @@ TEST(NVFuserTest, FusionReduction6_CUDA) { fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(2, bdimx); // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] @@ -4508,14 +4593,14 @@ TEST(NVFuserTest, FusionReduction6_CUDA) { at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1, 2}); testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultiGridReduction_CUDA) { +TEST_F(NVFuserTest, FusionMultiGridReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4540,7 +4625,7 @@ TEST(NVFuserTest, FusionMultiGridReduction_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); std::vector aten_outputs = { @@ -4548,7 +4633,7 @@ TEST(NVFuserTest, FusionMultiGridReduction_CUDA) { testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultiGridReduction2_CUDA) { +TEST_F(NVFuserTest, FusionMultiGridReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4566,7 +4651,7 @@ TEST(NVFuserTest, FusionMultiGridReduction2_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionReductionTFT_CUDA) { +TEST_F(NVFuserTest, FusionReductionTFT_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4575,7 +4660,8 @@ TEST(NVFuserTest, FusionReductionTFT_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); @@ -4613,7 +4699,7 @@ TEST(NVFuserTest, FusionReductionTFT_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -4621,7 +4707,7 @@ TEST(NVFuserTest, FusionReductionTFT_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { +TEST_F(NVFuserTest, FusionReductionOuterSplit_CUDA) { // based off FusionReduction4 Fusion fusion; FusionGuard fg(&fusion); @@ -4636,7 +4722,8 @@ TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2); + TensorView* tv3 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); // tv3[I0, R1] = tv2[I0, I1] TensorView* tv4 = makeSymbolicTensor(1); @@ -4676,7 +4763,7 @@ TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { at::Tensor t4 = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1, t4}); auto cg_outputs = fe.runFusion({t0, t1, t4}); auto t2 = t0.add(t1); @@ -4687,7 +4774,7 @@ TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) { &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBranches_CUDA) { +TEST_F(NVFuserTest, FusionBranches_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4699,7 +4786,7 @@ TEST(NVFuserTest, FusionBranches_CUDA) { fusion.addInput(tv1); fusion.addInput(tv2); - auto tv3 = add(tv0, new Double(1.0)); + auto tv3 = add(tv0, IrBuilder::create(1.0)); auto tv4 = add(tv3, tv1); auto tv5 = add(tv3, tv2); auto tv6 = add(tv4, tv5); @@ -4735,7 +4822,7 @@ TEST(NVFuserTest, FusionBranches_CUDA) { std::vector aten_inputs = {t0, t1, t2}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t3 = t0.add(1.0); @@ -4747,14 +4834,14 @@ TEST(NVFuserTest, FusionBranches_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast1_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1.5)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.5)); TensorView* tv2 = makeSymbolicTensor(2); fusion.addInput(tv2); @@ -4797,14 +4884,14 @@ TEST(NVFuserTest, FusionSimpleBCast1_CUDA) { std::vector aten_inputs = {t0, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast2_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4821,7 +4908,7 @@ TEST(NVFuserTest, FusionSimpleBCast2_CUDA) { TensorView* tv4 = makeSymbolicTensor(2); fusion.addInput(tv4); - TensorView* tv5 = sub(tv4, new Double(0.1)); + TensorView* tv5 = sub(tv4, IrBuilder::create(0.1)); TensorView* tv6 = broadcast(tv5, {true, false, false}); @@ -4856,28 +4943,30 @@ TEST(NVFuserTest, FusionSimpleBCast2_CUDA) { std::vector aten_inputs = {t0, t1, t4}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast3_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views std::vector dom; - dom.push_back(new IterDomain(new Int(0), new Int())); - dom.push_back(new IterDomain( - new Int(0), - new Int(1), + dom.push_back(IrBuilder::create( + IrBuilder::create(0), IrBuilder::create())); + dom.push_back(IrBuilder::create( + IrBuilder::create(0), + IrBuilder::create(1), ParallelType::Serial, IterType::BroadcastWithStride)); // tv0[I1, B{1}] - TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create(dom), DataType::Float); fusion.addInput(tv0); // tv1[I0, I1, I2] @@ -4908,26 +4997,28 @@ TEST(NVFuserTest, FusionSimpleBCast3_CUDA) { at::Tensor cg_output = at::empty({x, y, z}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast4_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views std::vector dom; - dom.push_back(new IterDomain( - new Int(0), - new Int(1), + dom.push_back(IrBuilder::create( + IrBuilder::create(0), + IrBuilder::create(1), ParallelType::Serial, IterType::BroadcastWithStride)); - dom.push_back(new IterDomain(new Int(0), new Int())); - TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); + dom.push_back(IrBuilder::create( + IrBuilder::create(0), IrBuilder::create())); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create(dom), DataType::Float); TensorView* tv1 = makeSymbolicTensor(3); fusion.addInput(tv0); @@ -4963,30 +5054,35 @@ TEST(NVFuserTest, FusionSimpleBCast4_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleBCast5_CUDA) { +TEST_F(NVFuserTest, FusionSimpleBCast5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); constexpr int m = 2, k = 3, n = 4; - auto zero = new Int(0); - auto M = new IterDomain(zero, new Int(m)); - auto K = new IterDomain(zero, new Int(k)); - auto N = new IterDomain(zero, new Int(n)); + auto zero = IrBuilder::create(0); + auto M = IrBuilder::create(zero, IrBuilder::create(m)); + auto K = IrBuilder::create(zero, IrBuilder::create(k)); + auto N = IrBuilder::create(zero, IrBuilder::create(n)); // Set up your input tensor views - TensorView* tv0 = - new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float); + TensorView* tv0 = IrBuilder::create( + IrBuilder::create( + std::vector({M, K}), std::vector({true, true})), + DataType::Float); // Note: IterDomain must not be reused, so K needs to be cloned. - TensorView* tv1 = new TensorView( - new TensorDomain({K->clone(), N}, {true, true}), DataType::Float); + TensorView* tv1 = IrBuilder::create( + IrBuilder::create( + std::vector({K->clone(), N}), + std::vector({true, true})), + DataType::Float); fusion.addInput(tv0); fusion.addInput(tv1); @@ -5018,21 +5114,21 @@ TEST(NVFuserTest, FusionSimpleBCast5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComplexBCast1_CUDA) { +TEST_F(NVFuserTest, FusionComplexBCast1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); int x = 2, y = 3, z = 4; auto tv0 = makeConcreteTensor({y}); - auto tv1 = div(tv0, new Double(2.0)); + auto tv1 = div(tv0, IrBuilder::create(2.0)); auto tv2 = broadcast(tv1, {false, true}); auto tv3 = makeConcreteTensor({y, z}); auto tv4 = mul(tv2, tv3); @@ -5074,21 +5170,21 @@ TEST(NVFuserTest, FusionComplexBCast1_CUDA) { std::vector aten_inputs = {t0, t3, t6}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComplexBCast2_CUDA) { +TEST_F(NVFuserTest, FusionComplexBCast2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); int x = 2, y = 3, z = 4; auto tv0 = makeConcreteTensor({y, z}); - auto tv1 = div(tv0, new Double(2.0)); + auto tv1 = div(tv0, IrBuilder::create(2.0)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = makeConcreteTensor({x, y}); @@ -5119,7 +5215,7 @@ TEST(NVFuserTest, FusionComplexBCast2_CUDA) { at::Tensor t4 = at::randn({x, y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t4}); auto cg_outputs = fe.runFusion({t0, t4}); auto t1 = t0.div(2.0); @@ -5131,7 +5227,7 @@ TEST(NVFuserTest, FusionComplexBCast2_CUDA) { &fusion, {cg_outputs}, {t0, t4}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5143,7 +5239,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1.0)); + auto tv2 = add(tv0, IrBuilder::create(1.0)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); @@ -5178,14 +5274,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) { std::vector aten_inputs = {t0, t1}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5197,7 +5293,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1.0)); + auto tv2 = add(tv0, IrBuilder::create(1.0)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); @@ -5232,14 +5328,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) { std::vector aten_inputs = {t0, t1}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5250,7 +5346,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1.0)); + auto tv2 = add(tv0, IrBuilder::create(1.0)); auto tv3 = add(tv2, tv1); fusion.addOutput(tv3); @@ -5266,14 +5362,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5283,7 +5379,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { TensorView* tv1 = makeConcreteTensor({4, 4, 8}); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = broadcast(tv2, {true, false, false}); TensorView* tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -5298,14 +5394,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5315,7 +5411,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) { TensorView* tv1 = makeSymbolicTensor(3); fusion.addInput(tv1); - TensorView* tv2 = add(tv0, new Double(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = broadcast(tv2, {true, false, true}); TensorView* tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -5336,14 +5432,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5371,7 +5467,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) { scheduleReduction(&fusion, reduction_params.value()); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}, reduction_params.value().lparams); auto cg_outputs = fe.runFusion({input0, input1}, reduction_params.value().lparams); @@ -5388,7 +5484,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) { reduction_params.value().lparams); } -TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing7_CUDA) { // Might be able to use this one without 6 as the heuristics in 6 may change // and this test is to cover the same issue. Fusion fusion; @@ -5417,15 +5513,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) { tv4->axis(0)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto at_t0 = at::randn({numel_x}, options); auto at_t1 = at::randn({numel_x, numel_y}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t1}); auto cg_outputs = fe.runFusion({at_t0, at_t1}); auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) @@ -5436,7 +5531,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing8_CUDA) { // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -5464,15 +5559,14 @@ TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) { tv4->axis(0)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto at_t0 = at::randn({numel_x}, options); auto at_t1 = at::randn({numel_x, numel_y}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t1}); auto cg_outputs = fe.runFusion({at_t0, at_t1}); auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) @@ -5483,7 +5577,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing9_CUDA) { // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -5493,7 +5587,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = mul(tv1, new Double(2)); + auto tv2 = mul(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); auto tv3 = makeSymbolicTensor(3); @@ -5513,7 +5607,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); auto at_t1 = at_t0.unsqueeze(-1); @@ -5525,7 +5619,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) { &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing10_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5539,7 +5633,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -5575,7 +5669,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { at::Tensor output = at::empty_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; @@ -5584,7 +5678,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedIndexing11_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5596,7 +5690,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv1, new Double(1.0)); + auto tv2 = add(tv1, IrBuilder::create(1.0)); auto tv3 = broadcast(tv2, {true, false, true, true}); auto tv4 = add(tv3, tv0); @@ -5631,7 +5725,7 @@ TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { std::vector aten_inputs = {t0, t1}; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5639,16 +5733,16 @@ TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) { } // Intended to stress the lowering of our code generator -TEST(NVFuserTest, FusionAdvancedLowering1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeConcreteTensor({9, 5}); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv1, new Double(3)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); TensorView* tv4 = sum(tv3, {1}); fusion.addOutput(tv2); @@ -5671,15 +5765,14 @@ TEST(NVFuserTest, FusionAdvancedLowering1_CUDA) { std::vector aten_outputs = {t2, t4}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5691,7 +5784,7 @@ TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { TensorView* tv2 = makeSymbolicTensor(3); fusion.addInput(tv2); - TensorView* tv3 = add(tv0, new Double(1)); + TensorView* tv3 = add(tv0, IrBuilder::create(1)); TensorView* tv4 = broadcast(tv3, {false, true}); TensorView* tv5 = add(tv4, tv1); TensorView* tv6 = add(tv5, tv2); @@ -5727,8 +5820,7 @@ TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { std::vector aten_outputs = {t6}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5736,7 +5828,7 @@ TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) { } // TODO: Complete test -TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5746,13 +5838,13 @@ TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { fusion.addInput(tv1); // [b0, i1] - auto tv2 = add(tv0, new Double(2.0)); + auto tv2 = add(tv0, IrBuilder::create(2.0)); // [i0, i1] - auto tv3 = add(tv1, new Double(3.0)); + auto tv3 = add(tv1, IrBuilder::create(3.0)); // [b0, i1] - auto tv4 = add(tv2, new Double(4.0)); + auto tv4 = add(tv2, IrBuilder::create(4.0)); // [io, i1] auto tv5 = add(tv2, tv3); @@ -5776,8 +5868,7 @@ TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { std::vector aten_outputs = {t4, t5}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5787,7 +5878,7 @@ TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) { // This excercises indexing with broadcast root axes. Non-broadcast // axes need to be preferred when propagating index exprs to root // axes. See, e.g., Index::getConsumerIndex_impl. -TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5804,9 +5895,6 @@ TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { tv4->split(0, 8); tv0->computeAt(tv4, 1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); const int bx = 10; const int by = 20; @@ -5815,6 +5903,8 @@ TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { at::Tensor t3 = at::randn({bx, by, bz}, options); std::vector aten_inputs = {t0, t3}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = @@ -5824,7 +5914,7 @@ TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedLowering5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5854,15 +5944,14 @@ TEST(NVFuserTest, FusionAdvancedLowering5_CUDA) { std::vector aten_outputs = {t3}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedLowering6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedLowering6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5902,8 +5991,7 @@ TEST(NVFuserTest, FusionAdvancedLowering6_CUDA) { std::vector aten_outputs = {t5, t7}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -5911,7 +5999,7 @@ TEST(NVFuserTest, FusionAdvancedLowering6_CUDA) { } // Test a simple Gemm but also play around with fusion executor features -TEST(NVFuserTest, FusionSimpleGemm_CUDA) { +TEST_F(NVFuserTest, FusionSimpleGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -5978,7 +6066,7 @@ TEST(NVFuserTest, FusionSimpleGemm_CUDA) { at::Tensor t1 = at::randn({K, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); // Lets specify a few bounds in launch params to make sure it works fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); @@ -5996,7 +6084,7 @@ TEST(NVFuserTest, FusionSimpleGemm_CUDA) { } // Softmax with a 1D tensor. Parallelized only with a single thread block. -TEST(NVFuserTest, FusionSoftmax1D_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax1D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6042,7 +6130,7 @@ TEST(NVFuserTest, FusionSoftmax1D_CUDA) { at::Tensor t3_output = at::empty_like(cg_output, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); fe.runFusion({t0}, {cg_output}); auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false); @@ -6051,7 +6139,7 @@ TEST(NVFuserTest, FusionSoftmax1D_CUDA) { } // Softmax with a 1D tensor with input normalization. -TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6063,8 +6151,8 @@ TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { fusion.addInput(input_tv0); // Normalize with the max value before computing exp. - TensorView* max_val_tv1 = - reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0); + TensorView* max_val_tv1 = reductionOp( + BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true}); TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); @@ -6111,7 +6199,7 @@ TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { at::Tensor t3_output = at::empty({dimx}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); @@ -6121,7 +6209,7 @@ TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { // Softmax with a 3D tensor, where the inner-most 3rd dimension is // normalized. Pallelized with multiple thread blocks. -TEST(NVFuserTest, FusionSoftmax3D_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax3D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6171,7 +6259,7 @@ TEST(NVFuserTest, FusionSoftmax3D_CUDA) { at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); @@ -6181,7 +6269,7 @@ TEST(NVFuserTest, FusionSoftmax3D_CUDA) { } // Softmax with a 3D tensor with input normalization. -TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6195,8 +6283,8 @@ TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { fusion.addInput(input_tv0); // Normalize with the max value before computing exp. - TensorView* max_val_tv1 = - reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0); + TensorView* max_val_tv1 = reductionOp( + BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true}); TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); @@ -6246,7 +6334,7 @@ TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); @@ -6254,7 +6342,7 @@ TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { +TEST_F(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6265,7 +6353,7 @@ TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { auto tv1 = sum(tv0, {1}); auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv0, new Double(1.0)); + auto tv3 = add(tv0, IrBuilder::create(1.0)); auto tv4 = mul(tv2, tv3); @@ -6280,10 +6368,7 @@ TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { } // Similar to FusionReduction but uses grid reduction -TEST(NVFuserTest, FusionGridReduction1_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridReduction1_CUDA) { const int gdimx = 32; const int bdimx = 128; @@ -6295,10 +6380,13 @@ TEST(NVFuserTest, FusionGridReduction1_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -6331,7 +6419,7 @@ TEST(NVFuserTest, FusionGridReduction1_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6341,10 +6429,7 @@ TEST(NVFuserTest, FusionGridReduction1_CUDA) { } // Same test as the above but uses BIDy and TIDx for reduction -TEST(NVFuserTest, FusionGridReduction2_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridReduction2_CUDA) { const int gdimy = 32; const int bdimx = 128; @@ -6356,10 +6441,13 @@ TEST(NVFuserTest, FusionGridReduction2_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -6391,7 +6479,7 @@ TEST(NVFuserTest, FusionGridReduction2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6400,7 +6488,7 @@ TEST(NVFuserTest, FusionGridReduction2_CUDA) { } // Same test but uses BIDy and BIDz for reduction. No TID used. -TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction3dim1_CUDA) { // Grid reductions when there aren't any threads are serial reductions // keep these numbers low so our error isn't too high compared to normal cuda // reductions @@ -6415,10 +6503,13 @@ TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, gdimy); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] @@ -6450,7 +6541,7 @@ TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6459,7 +6550,7 @@ TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) { } // Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0 -TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction3dim0_CUDA) { // Grid reductions when there aren't any threads are serial reductions // keep these numbers low so our error isn't too high compared to normal cuda // reductions @@ -6474,10 +6565,13 @@ TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { fusion.addInput(tv0); // tv1[R0, I1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {0}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {0}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(0, gdimy); // tv1[R0o, R0i{128}, I1] = tv0[I0, I1] @@ -6507,7 +6601,7 @@ TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({0}); @@ -6516,7 +6610,7 @@ TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) { } // This is similar to the FusionReduction, but swaps BIDx and TIDx -TEST(NVFuserTest, FusionGridReduction4_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6528,10 +6622,13 @@ TEST(NVFuserTest, FusionGridReduction4_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, gdimx); // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1] @@ -6570,7 +6667,7 @@ TEST(NVFuserTest, FusionGridReduction4_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6580,7 +6677,7 @@ TEST(NVFuserTest, FusionGridReduction4_CUDA) { // Grid reduction with 2D thread blocks but only TIDx and BIDx are // mapped to a reduction dim -TEST(NVFuserTest, FusionGridReduction5_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6593,10 +6690,13 @@ TEST(NVFuserTest, FusionGridReduction5_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{64}] = tv0[I0, I1] @@ -6624,7 +6724,7 @@ TEST(NVFuserTest, FusionGridReduction5_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({1}); @@ -6632,7 +6732,7 @@ TEST(NVFuserTest, FusionGridReduction5_CUDA) { } // Similar to FusionGridReduction1 but with 3D tensors -TEST(NVFuserTest, FusionGridReduction6_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6641,10 +6741,13 @@ TEST(NVFuserTest, FusionGridReduction6_CUDA) { fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); // Splitting for TID tv1->split(2, 128); @@ -6686,7 +6789,7 @@ TEST(NVFuserTest, FusionGridReduction6_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({1, 2}); @@ -6696,7 +6799,7 @@ TEST(NVFuserTest, FusionGridReduction6_CUDA) { } // See issue #1049 -TEST(NVFuserTest, FusionGridReduction7_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6718,7 +6821,7 @@ TEST(NVFuserTest, FusionGridReduction7_CUDA) { at::Tensor cg_output = at::empty({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = input.sum({0}); @@ -6726,7 +6829,7 @@ TEST(NVFuserTest, FusionGridReduction7_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReduction8_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction8_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6746,7 +6849,7 @@ TEST(NVFuserTest, FusionGridReduction8_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = input.sum({0}); @@ -6754,10 +6857,7 @@ TEST(NVFuserTest, FusionGridReduction8_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReduction9_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridReduction9_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6788,7 +6888,7 @@ TEST(NVFuserTest, FusionGridReduction9_CUDA) { at::ArrayRef aten_inputs = {t0, t2}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_output = fe.runFusion(aten_inputs); auto aten_output = t0.sum({1}).add(t2); @@ -6796,7 +6896,7 @@ TEST(NVFuserTest, FusionGridReduction9_CUDA) { testValidate(&fusion, cg_output, {t0, t2}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReduction10_CUDA) { +TEST_F(NVFuserTest, FusionGridReduction10_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6831,7 +6931,7 @@ TEST(NVFuserTest, FusionGridReduction10_CUDA) { at::Tensor t0 = at::randn({numel_w, numel_x, numel_y, numel_z}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_output = fe.runFusion({t0}); auto aten_output = t0.sum({1, 2, 3}); @@ -6839,7 +6939,7 @@ TEST(NVFuserTest, FusionGridReduction10_CUDA) { testValidate(&fusion, cg_output, {t0}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) { +TEST_F(NVFuserTest, FusionNonRedAxisBind_CUDA) { int bid_x = 3; int tid_x = 2; int red_dim = 0; @@ -6851,8 +6951,8 @@ TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); tv1->split(-1, tid_x); @@ -6863,7 +6963,7 @@ TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) { at::Tensor input = at::randn({16, bid_x * tid_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = input.to(at::kDouble).sum({red_dim}); @@ -6871,7 +6971,7 @@ TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) { testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSplitBCast_CUDA) { +TEST_F(NVFuserTest, FusionSplitBCast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6881,8 +6981,8 @@ TEST(NVFuserTest, FusionSplitBCast_CUDA) { fusion.addInput(input_tv0); fusion.addInput(input_tv1); - TensorView* sum_tv2 = - reductionOp(BinaryOpType::Add, {2}, new Double(0), input_tv0); + TensorView* sum_tv2 = reductionOp( + BinaryOpType::Add, {2}, IrBuilder::create(0), input_tv0); TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true}); TensorView* output_tv4 = div(input_tv1, bcast_tv3); @@ -6915,11 +7015,11 @@ TEST(NVFuserTest, FusionSplitBCast_CUDA) { at::Tensor cg_output = at::empty({32, 32, 128}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}); fe.runFusion({t0, t1}, {cg_output}); } -TEST(NVFuserTest, FusionBCastInnerDim_CUDA) { +TEST_F(NVFuserTest, FusionBCastInnerDim_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6933,7 +7033,7 @@ TEST(NVFuserTest, FusionBCastInnerDim_CUDA) { TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast()); } -TEST(NVFuserTest, FusionBCastReduce_CUDA) { +TEST_F(NVFuserTest, FusionBCastReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -6949,14 +7049,16 @@ TEST(NVFuserTest, FusionBCastReduce_CUDA) { // Multiple consumer reduction with computeAt // https://github.com/csarofeen/pytorch/issues/110 -TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) { +TEST_F(NVFuserTest, FusionReductionMultiConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); auto tv1 = unaryOp(UnaryOpType::Exp, tv0); - auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Double(0), tv1); - auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Double(0), tv1); + auto tv2 = + reductionOp(BinaryOpType::Max, {-1}, IrBuilder::create(0), tv1); + auto tv3 = + reductionOp(BinaryOpType::Min, {-1}, IrBuilder::create(0), tv1); auto tv4 = add(tv2, tv3); fusion.addOutput(tv4); tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort); @@ -6964,7 +7066,7 @@ TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) { TORCH_CHECK(tv1->getComputeAtPosition() == 2); } -TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { for (const auto i : c10::irange(2)) { Fusion fusion; FusionGuard fg(&fusion); @@ -6973,8 +7075,8 @@ TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = add(tv1, tv2); // Set outputs tv2 or tv1 and then tv3 if (i == 0) { @@ -6996,7 +7098,7 @@ TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { aten_input + 1, (aten_input + 1) * 2}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -7004,7 +7106,7 @@ TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { } } -TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7012,8 +7114,8 @@ TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = add(tv1, tv2); fusion.addOutput(tv3); @@ -7029,14 +7131,14 @@ TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { at::Tensor cg_output = at::empty_like(aten_input, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7045,10 +7147,10 @@ TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { TensorView* tv0 = makeConcreteTensor({dimx, dimy}); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv2, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv2, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = mul(tv2, tv4); fusion.addOutput(tv5); @@ -7065,14 +7167,14 @@ TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { auto aten_output = t2.mul(t4); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { +TEST_F(NVFuserTest, FusionZeroDimComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7080,7 +7182,7 @@ TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { fusion.addInput(tv0); auto tv1 = sum(tv0, {0}); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); TORCH_CHECK(tv2->nDims() == 0); tv1->computeAt(tv2, 0); @@ -7090,14 +7192,14 @@ TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) { auto aten_output = aten_input.to(at::kDouble).sum() + 1; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionZeroDimBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7130,14 +7232,14 @@ TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) { at::Tensor cg_output = at::empty({}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_output}); testValidate( &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroDimReduction_CUDA) { +TEST_F(NVFuserTest, FusionZeroDimReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7166,14 +7268,14 @@ TEST(NVFuserTest, FusionZeroDimReduction_CUDA) { at::Tensor cg_output = at::empty({}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) { +TEST_F(NVFuserTest, FusionBCastAfterReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 128; @@ -7218,14 +7320,14 @@ TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) { std::vector aten_inputs = {t0, t4}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t4}); auto cg_outputs = fe.runFusion({t0, t4}); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionOutputBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionOutputBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7243,15 +7345,14 @@ TEST(NVFuserTest, FusionOutputBroadcast_CUDA) { auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { +TEST_F(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7270,15 +7371,14 @@ TEST(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { aten_input.to(at::kDouble).sum({0, 2, -1}, /*keepdim=*/true); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { +TEST_F(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -7291,7 +7391,11 @@ TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { fusion.addInput(tv0); TensorView* tv1 = reductionOp( - BinaryOpType::Add, {red_dim}, new Double(0), tv0, /*keep_dim=*/true); + BinaryOpType::Add, + {red_dim}, + IrBuilder::create(0), + tv0, + /*keep_dim=*/true); fusion.addOutput(tv1); @@ -7307,11 +7411,10 @@ TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); scheduleReduction(&fusion, reduction_params.value()); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto lparams = reduction_params.value().lparams; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7325,7 +7428,7 @@ TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { lparams); } -TEST(NVFuserTest, FusionSumTo_CUDA) { +TEST_F(NVFuserTest, FusionSumTo_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7340,7 +7443,7 @@ TEST(NVFuserTest, FusionSumTo_CUDA) { sum_to_shape.begin(), sum_to_shape.end(), std::back_inserter(sum_to_symb), - [](int s) -> Int* { return new Int(s); }); + [](int s) -> Int* { return IrBuilder::create(s); }); TensorView* tv0 = makeConcreteTensor(tensor_shape); fusion.addInput(tv0); @@ -7355,8 +7458,7 @@ TEST(NVFuserTest, FusionSumTo_CUDA) { auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); TORCH_CHECK( @@ -7367,7 +7469,7 @@ TEST(NVFuserTest, FusionSumTo_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSumToNoop_CUDA) { +TEST_F(NVFuserTest, FusionSumToNoop_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7382,7 +7484,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { sum_to_shape.begin(), sum_to_shape.end(), std::back_inserter(sum_to_symb), - [](int s) -> Int* { return new Int(s); }); + [](int s) -> Int* { return IrBuilder::create(s); }); TensorView* tv0 = makeConcreteTensor(tensor_shape); fusion.addInput(tv0); @@ -7390,7 +7492,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { TensorView* tv1 = sum_to(tv0, sum_to_symb); // Dummy operator to avoid tv0 both input and output - TensorView* tv2 = add(tv1, new Double(0)); + TensorView* tv2 = add(tv1, IrBuilder::create(0)); fusion.addOutput(tv2); const auto options = @@ -7399,8 +7501,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { at::Tensor aten_input = at::randn(tensor_shape_ref, options); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); @@ -7412,7 +7513,7 @@ TEST(NVFuserTest, FusionSumToNoop_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReductionScheduler_CUDA) { +TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -7424,8 +7525,8 @@ TEST(NVFuserTest, FusionReductionScheduler_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -7442,7 +7543,7 @@ TEST(NVFuserTest, FusionReductionScheduler_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); @@ -7458,7 +7559,7 @@ TEST(NVFuserTest, FusionReductionScheduler_CUDA) { } // Simple reduction parallelized on a symbolic size. -TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { +TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7467,7 +7568,8 @@ TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); // Interface should just be a direct split with a Parallel type. We can @@ -7501,7 +7603,7 @@ TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7515,7 +7617,7 @@ TEST(NVFuserTest, FusionSymbolicReduction_CUDA) { lparams); } -TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { const std::vector red_dims = {0, 2}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions @@ -7530,8 +7632,8 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -7547,7 +7649,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); fe.runFusion({aten_input}, {cg_output}, lparams); testValidate( @@ -7561,7 +7663,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { lparams); } -TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { const std::vector red_dims = {1, 3}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions @@ -7575,8 +7677,8 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -7590,7 +7692,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7604,7 +7706,7 @@ TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { lparams); } -TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { std::vector dtypes = { DataType::Double, DataType::Float, DataType::Half}; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 @@ -7661,8 +7763,7 @@ TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -7678,7 +7779,7 @@ TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { } } -TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { +TEST_F(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { std::vector dtypes = { DataType::Double, DataType::Float, DataType::Half}; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 @@ -7740,8 +7841,7 @@ TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({axis}); testValidate( @@ -7759,14 +7859,14 @@ TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { } } -TEST(NVFuserTest, FusionCacheBefore_CUDA) { +TEST_F(NVFuserTest, FusionCacheBefore_CUDA) { // TVM Cache Write Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, new Double(1.0)); - TensorView* tv2 = mul(tv1, new Double(3.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -7790,21 +7890,21 @@ TEST(NVFuserTest, FusionCacheBefore_CUDA) { at::Tensor aten_output = (aten_input + 1.0) * 3.0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheAfter_CUDA) { +TEST_F(NVFuserTest, FusionCacheAfter_CUDA) { // TVM Cache Read Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, new Double(1.0)); - TensorView* tv2 = mul(tv1, new Double(3.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -7828,20 +7928,20 @@ TEST(NVFuserTest, FusionCacheAfter_CUDA) { at::Tensor aten_output = (aten_input + 1.0) * 3.0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheFork_CUDA) { +TEST_F(NVFuserTest, FusionCacheFork_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, new Double(1.0)); - TensorView* tv2 = mul(tv1, new Double(3.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); fusion.addInput(tv0); fusion.addOutput(tv1); fusion.addOutput(tv2); @@ -7873,7 +7973,7 @@ TEST(NVFuserTest, FusionCacheFork_CUDA) { at::Tensor aten_output2 = aten_output1 * 3.0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -7885,7 +7985,7 @@ TEST(NVFuserTest, FusionCacheFork_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionCacheIndirect_CUDA) { +TEST_F(NVFuserTest, FusionCacheIndirect_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7927,14 +8027,14 @@ TEST(NVFuserTest, FusionCacheIndirect_CUDA) { at::Tensor aten_output = (t1 + (t2 - t3)) - t0; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheBcast_CUDA) { +TEST_F(NVFuserTest, FusionCacheBcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -7986,22 +8086,22 @@ TEST(NVFuserTest, FusionCacheBcast_CUDA) { t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0)); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { +TEST_F(NVFuserTest, FusionCacheMultiConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv0, new Double(1)); - TensorView* tv4 = add(tv3, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(1)); + TensorView* tv4 = add(tv3, IrBuilder::create(2)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -8025,7 +8125,7 @@ TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { auto aten_output = (aten_input + 1) + 2; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -8037,7 +8137,7 @@ TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionSmem_CUDA) { +TEST_F(NVFuserTest, FusionSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8084,7 +8184,7 @@ TEST(NVFuserTest, FusionSmem_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); testValidate( @@ -8093,7 +8193,7 @@ TEST(NVFuserTest, FusionSmem_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemReduce_CUDA) { +TEST_F(NVFuserTest, FusionSmemReduce_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8133,7 +8233,7 @@ TEST(NVFuserTest, FusionSmemReduce_CUDA) { at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1}); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( @@ -8141,7 +8241,7 @@ TEST(NVFuserTest, FusionSmemReduce_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { +TEST_F(NVFuserTest, FusionSmemBlockGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8202,7 +8302,7 @@ TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1}); auto cg_outputs = fe.runFusion({t0, t1}); testValidate( @@ -8211,7 +8311,7 @@ TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { +TEST_F(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8291,7 +8391,7 @@ TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -8300,7 +8400,7 @@ TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8309,7 +8409,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { TensorView* max_val = reductionOp( BinaryOpType::Max, {-1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), x); // (M) TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) TensorView* x_max_sub = sub(x, bcast_max); // (M, N) @@ -8336,7 +8436,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { bcast_sum, softmax}); - auto tidx = new Int(); + auto tidx = IrBuilder::create(); fusion.addInput(tidx); for (auto tensor : all_tensors) { @@ -8363,7 +8463,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input, 128}); auto cg_outputs = fe.runFusion({aten_input, 128}); testValidate( @@ -8375,7 +8475,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { +TEST_F(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8401,7 +8501,7 @@ TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { auto lparams = reduction_params.value().lparams; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -8415,7 +8515,7 @@ TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { lparams); } -TEST(NVFuserTest, TestMaskSoftmax_CUDA) { +TEST_F(NVFuserTest, FusionTestMaskSoftmax_CUDA) { // This test is testing the usage of all padding tokens // with softmax like Bert might might use in a full padding // sequence. @@ -8456,7 +8556,7 @@ TEST(NVFuserTest, TestMaskSoftmax_CUDA) { auto lparams = reduction_params.value().lparams; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input, aten_mask}, lparams); auto cg_outputs = fe.runFusion({aten_input, aten_mask}, lparams); testValidate( @@ -8470,7 +8570,7 @@ TEST(NVFuserTest, TestMaskSoftmax_CUDA) { lparams); } -TEST(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { +TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -8558,13 +8658,13 @@ TEST(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { +TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); const float kEps = 1e-5; - Double* eps_ptr = new Double(kEps); + Double* eps_ptr = IrBuilder::create(kEps); std::vector input_shape{20, 100, 35, 67}; std::vector norm_shape{67}; @@ -8594,7 +8694,7 @@ TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { auto lparams = reduction_params.value().lparams; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -8610,8 +8710,9 @@ TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { lparams); } -TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } auto fusion = std::make_unique(); @@ -8633,8 +8734,8 @@ TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { fusion->addInput(running_mean); fusion->addInput(running_var); - Double* momentum = new Double(kMomentum); - Double* eps = new Double(kEps); + Double* momentum = IrBuilder::create(kMomentum); + Double* eps = IrBuilder::create(kEps); auto result = batch_norm( input, weight, bias, running_mean, running_var, kTraining, momentum, eps); @@ -8681,7 +8782,7 @@ TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { ""); } -TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { +TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8697,12 +8798,12 @@ TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { TensorView* max_sx = reductionOp( BinaryOpType::Max, {-1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), sx); // (M) TensorView* max_dx = reductionOp( BinaryOpType::Max, {-1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), dx); // (M) // Reduction => merge local and shared memory TensorViews @@ -8804,7 +8905,7 @@ TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { aten_output.narrow(1, static_size, dimy - static_size); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_static_in, aten_dynamic_in}); fe.runFusion( {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out}); @@ -8817,7 +8918,7 @@ TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { +TEST_F(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -8830,10 +8931,10 @@ TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { fusion.addInput(sx); fusion.addInput(dx); - Double* gamma = new Double(); - Double* beta = new Double(); - Double* eps = new Double(); - Int* N = new Int(); + Double* gamma = IrBuilder::create(); + Double* beta = IrBuilder::create(); + Double* eps = IrBuilder::create(); + Int* N = IrBuilder::create(); fusion.addInput(gamma); fusion.addInput(beta); fusion.addInput(eps); @@ -8982,7 +9083,7 @@ TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out}); auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); @@ -9003,16 +9104,16 @@ TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views auto x = makeSymbolicTensor(2); - Double* gamma = new Double(); - Double* beta = new Double(); - Double* eps = new Double(); - Int* N = new Int(); + Double* gamma = IrBuilder::create(); + Double* beta = IrBuilder::create(); + Double* eps = IrBuilder::create(); + Int* N = IrBuilder::create(); fusion.addInput(x); fusion.addInput(gamma); fusion.addInput(beta); @@ -9062,7 +9163,7 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { norm_gamma, norm_gamma_beta}); - auto tidx = new Int(); + auto tidx = IrBuilder::create(); fusion.addInput(tidx); for (auto tensor : all_tensors) { @@ -9105,20 +9206,21 @@ TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { aten_input, kGamma, kBeta, kEps, dimy, TIDX}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addInput(tv0); fusion.addOutput(tv1); // tv1[I0, R1] = tv0[I0, I1] @@ -9150,7 +9252,7 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); auto cg_outputs = fe.runFusion({aten_input}, lparams); testValidate( @@ -9165,12 +9267,12 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); } -TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Algorithm - Int* sym_bsx = new Int(); + Int* sym_bsx = IrBuilder::create(); TensorView* tv0 = makeSymbolicTensor(3); // M, K, N fusion.addInput(tv0); fusion.addInput(sym_bsx); @@ -9213,7 +9315,7 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input, runtime_threadIdx_dim}, lparams); auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams); testValidate( @@ -9229,11 +9331,11 @@ TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Int* sym_bsx = new Int(); + Int* sym_bsx = IrBuilder::create(); TensorView* tv0 = makeSymbolicTensor(2); // (M, K) TensorView* tv1 = makeSymbolicTensor(2); // (K, N) TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) @@ -9278,7 +9380,7 @@ TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { LaunchParams lparams(-1, -1, -1, BSX, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( @@ -9294,14 +9396,16 @@ TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { +TEST_F(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = new Int(); // bound to threadIdx.z - Int* symbolic_split_k_tile_dim = new Int(); // bound to blockIdx.x - Int* symbolic_block_k_tile_dim = new Int(); // bound to threadIdx.x + Int* symbolic_m_tile_dim = IrBuilder::create(); // bound to threadIdx.z + Int* symbolic_split_k_tile_dim = + IrBuilder::create(); // bound to blockIdx.x + Int* symbolic_block_k_tile_dim = + IrBuilder::create(); // bound to threadIdx.x // Compile-time integer for tiling int n_smem_tile = 8; // bound to threadIdx.y @@ -9397,10 +9501,6 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - FusionExecutor fe; - // Generate CUDA and compile with nvRTC - fe.compileFusion(&fusion); - // Runtime tiling int m_tile = 4; // bound to threadIdx.z int split_k = 7; // bound to blockIdx.x @@ -9410,6 +9510,9 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); + FusionExecutor fe; + // Generate CUDA and compile with nvRTC + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -9418,13 +9521,14 @@ TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); } -TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { +TEST_F(NVFuserTest, FusionGlobalIntermediate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); fusion.addInput(tv0); fusion.addOutput(tv1); // tv1[I0, R1] = tv0[I0, I1] @@ -9455,7 +9559,7 @@ TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}, lparams); auto cg_outputs = fe.runFusion({input}, lparams); auto aten_output = input.to(at::kDouble).sum({1}); @@ -9470,7 +9574,7 @@ TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) { lparams); } -TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { +TEST_F(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9504,18 +9608,18 @@ TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0, t1, t2, t3}); auto cg_outputs = fe.runFusion({t0, t1, t2, t3}); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConstCheck_CUDA) { +TEST_F(NVFuserTest, FusionConstCheck_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - auto one = new Int(1); + auto one = IrBuilder::create(1); TORCH_CHECK(one->isConstScalar()); auto one_x2 = mul(one, one); @@ -9528,7 +9632,7 @@ TEST(NVFuserTest, FusionConstCheck_CUDA) { TORCH_CHECK(one_x4->isConstScalar()); } -TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { +TEST_F(NVFuserTest, FusionUnrollWithAlloc_CUDA) { const std::vector tensor_dims_in = {128, 128}; Fusion fusion; FusionGuard fg(&fusion); @@ -9537,8 +9641,9 @@ TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(0)); - TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1); + TensorView* tv1 = add(tv0, IrBuilder::create(0)); + TensorView* tv2 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); fusion.addOutput(tv2); const auto options = @@ -9562,7 +9667,7 @@ TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { tv1->computeAt(tv2_rf, -1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto aten_output = (input + 0).to(at::kDouble).sum(1); @@ -9571,12 +9676,12 @@ TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) { } // Test isZeroInt -TEST(NVFuserTest, FusionIsZeroInt_CUDA) { +TEST_F(NVFuserTest, FusionIsZeroInt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Int* x = new Int(0); - Int* y = new Int(1); + Int* x = IrBuilder::create(0); + Int* y = IrBuilder::create(1); Val* z = mul(x, y); TORCH_CHECK(x->isZeroInt()); TORCH_CHECK(!y->isZeroInt()); @@ -9584,12 +9689,12 @@ TEST(NVFuserTest, FusionIsZeroInt_CUDA) { } // Test isOneInt -TEST(NVFuserTest, FusionIsOneInt_CUDA) { +TEST_F(NVFuserTest, FusionIsOneInt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - Int* x = new Int(1); - Int* y = new Int(1); + Int* x = IrBuilder::create(1); + Int* y = IrBuilder::create(1); Val* z = mul(x, y); TORCH_CHECK(x->isOneInt()); TORCH_CHECK(y->isOneInt()); @@ -9599,7 +9704,7 @@ TEST(NVFuserTest, FusionIsOneInt_CUDA) { // This is to verify no cycle of computeAt is created. A more complex // variation of this pattern appears in one of the Python tests // (test_random_topo). -TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9607,12 +9712,12 @@ TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { fusion.addInput(tv0); // Common intermediate tensor - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); // tv1 -> tv2 - auto tv2 = add(tv1, new Double(2)); + auto tv2 = add(tv1, IrBuilder::create(2)); // tv1 -> tv3 -> tv4 - auto tv3 = add(tv1, new Double(3)); - auto tv4 = add(tv3, new Double(4)); + auto tv3 = add(tv1, IrBuilder::create(3)); + auto tv4 = add(tv3, IrBuilder::create(4)); // NOTE: This should no longer occur as of PR #201. // The order of adding outputs matters. If tv3 is added before tv4, @@ -9639,7 +9744,7 @@ TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { auto t4 = t3 + 4; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); std::vector aten_outputs = {t2, t4, t3}; @@ -9647,7 +9752,7 @@ TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9655,10 +9760,10 @@ TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv0, new Double(2)); - TensorView* tv3 = add(tv1, new Double(3)); - TensorView* tv4 = add(tv1, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); + TensorView* tv4 = add(tv1, IrBuilder::create(4)); fusion.addOutput(tv2); fusion.addOutput(tv3); @@ -9666,9 +9771,6 @@ TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { tv1->computeAt(tv3, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({10, 10}, options); @@ -9684,12 +9786,14 @@ TEST(NVFuserTest, FusionTraversalOrder1_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -9697,11 +9801,11 @@ TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv1, tv3); @@ -9712,9 +9816,6 @@ TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { tv1->computeAt(tv5, -1); tv3->computeAt(tv5, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({10, 10}, options); @@ -9731,13 +9832,15 @@ TEST(NVFuserTest, FusionTraversalOrder2_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder3_CUDA) { for (const auto i : c10::irange(2)) { Fusion fusion; FusionGuard fg(&fusion); @@ -9745,11 +9848,11 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv1, tv3); @@ -9774,9 +9877,6 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { compute_at_outer->computeAt(tv5, -2); compute_at_inner->computeAt(tv5, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); auto t1 = aten_input + 1; @@ -9792,6 +9892,8 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( @@ -9799,25 +9901,25 @@ TEST(NVFuserTest, FusionTraversalOrder3_CUDA) { } } -TEST(NVFuserTest, FusionTraversalOrder4_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // First tree TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv1, new Double(3)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); fusion.addOutput(tv2); fusion.addOutput(tv3); // Second tree TensorView* tv4 = makeSymbolicTensor(1); fusion.addInput(tv4); - TensorView* tv5 = add(tv4, new Double(5)); - TensorView* tv6 = add(tv5, new Double(6)); - TensorView* tv7 = add(tv5, new Double(7)); + TensorView* tv5 = add(tv4, IrBuilder::create(5)); + TensorView* tv6 = add(tv5, IrBuilder::create(6)); + TensorView* tv7 = add(tv5, IrBuilder::create(7)); fusion.addOutput(tv6); fusion.addOutput(tv7); @@ -9844,23 +9946,23 @@ TEST(NVFuserTest, FusionTraversalOrder4_CUDA) { at::empty_like(t0, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); fe.runFusion(aten_inputs, cg_outputs); testValidate( &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv2, tv4); fusion.addOutput(tv1); @@ -9870,9 +9972,6 @@ TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { tv2->computeAt(tv5, -1); tv4->computeAt(tv5, -1); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); std::vector cg_outputs = { @@ -9880,6 +9979,8 @@ TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); auto t1 = aten_input + 1; @@ -9894,16 +9995,16 @@ TEST(NVFuserTest, FusionTraversalOrder5_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv0, new Double(2)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(2)); TensorView* tv3 = add(tv1, tv2); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); fusion.addOutput(tv4); @@ -9916,9 +10017,6 @@ TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { tv1->computeAt(tv3, -1); tv2->computeAt(tv3, -2); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); @@ -9929,22 +10027,24 @@ TEST(NVFuserTest, FusionTraversalOrder6_CUDA) { at::Tensor cg_output = at::empty_like(aten_input, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { +TEST_F(NVFuserTest, FusionTraversalOrder7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(2)); - TensorView* tv3 = add(tv0, new Double(3)); - TensorView* tv4 = add(tv3, new Double(4)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); TensorView* tv5 = add(tv2, tv4); fusion.addOutput(tv5); @@ -9963,9 +10063,6 @@ TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { tv2->computeAt(tv5, -4); tv4->computeAt(tv5, -3); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({100}, options); @@ -9976,6 +10073,9 @@ TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { auto aten_output = t2 + t4; at::Tensor cg_output = at::empty_like(aten_input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( @@ -9983,7 +10083,7 @@ TEST(NVFuserTest, FusionTraversalOrder7_CUDA) { } // Test predication of grid reduction -TEST(NVFuserTest, FusionThreadPredicate_CUDA) { +TEST_F(NVFuserTest, FusionThreadPredicate_CUDA) { const int gdimx = 4; const int bdimx = 128; @@ -9993,9 +10093,10 @@ TEST(NVFuserTest, FusionThreadPredicate_CUDA) { TensorView* tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1); - TensorView* tv3 = add(tv0, new Double(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(2)); fusion.addOutput(tv3); fusion.addOutput(tv2); @@ -10036,14 +10137,14 @@ TEST(NVFuserTest, FusionThreadPredicate_CUDA) { at::empty_like(aten_input, options), at::empty({numel_x}, options)}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, cg_outputs); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionLSTMCell_CUDA) { +TEST_F(NVFuserTest, FusionLSTMCell_CUDA) { const int hidden_features = 512; const int batch_size = 64; @@ -10116,14 +10217,14 @@ TEST(NVFuserTest, FusionLSTMCell_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { +TEST_F(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10131,7 +10232,7 @@ TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { TensorView* tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - TensorView* tv1 = mul(tv0, new Double(0.5)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); TensorView* tv2 = broadcast(tv1, {true, false}); TensorView* tv3 = broadcast(tv1, {false, true}); TensorView* tv4 = add(tv2, tv3); @@ -10142,7 +10243,7 @@ TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { ASSERT_ANY_THROW(tv1->computeAt(tv3, -1)); } -TEST(NVFuserTest, FusionReductionHalf_CUDA) { +TEST_F(NVFuserTest, FusionReductionHalf_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10151,7 +10252,7 @@ TEST(NVFuserTest, FusionReductionHalf_CUDA) { fusion.addInput(tv0); auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = add(tv1, new Double(1.0)); + auto tv2 = add(tv1, IrBuilder::create(1.0)); auto tv3 = sum(tv2, {2}); auto tv4 = castOp(DataType::Half, tv3); @@ -10172,7 +10273,7 @@ TEST(NVFuserTest, FusionReductionHalf_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); @@ -10189,7 +10290,7 @@ TEST(NVFuserTest, FusionReductionHalf_CUDA) { lparams); } -TEST(NVFuserTest, FusionReduceSingle_CUDA) { +TEST_F(NVFuserTest, FusionReduceSingle_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10205,7 +10306,7 @@ TEST(NVFuserTest, FusionReduceSingle_CUDA) { // Grab only tensor views, though there shouldn't be any other type FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}); @@ -10214,7 +10315,7 @@ TEST(NVFuserTest, FusionReduceSingle_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -10226,8 +10327,8 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim, 2}, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim, 2}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); const auto options = @@ -10241,7 +10342,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({red_dim, 2}); @@ -10257,7 +10358,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { lparams); } -TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -10269,10 +10370,11 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); - TensorView* tv2 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv1); + TensorView* tv2 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv1); fusion.addOutput(tv2); const auto options = @@ -10287,7 +10389,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({1, 2}); @@ -10303,7 +10405,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { lparams); } -TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; @@ -10315,10 +10417,11 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0); + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); - TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1); + TensorView* tv2 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); fusion.addOutput(tv2); const auto options = @@ -10332,7 +10435,7 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}, lparams); // no broadcasting needed, omitting the last optional argument; auto cg_outputs = fe.runFusion({aten_input}, lparams); auto aten_output = aten_input.to(at::kDouble).sum({2, 1}); @@ -10348,24 +10451,27 @@ TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { lparams); } -TEST(NVFuserTest, FusionTrivialReduction_CUDA) { +TEST_F(NVFuserTest, FusionTrivialReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeConcreteTensor({10, 20, 1}); fusion.addInput(tv0); - TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); fusion.addOutput(tv1); - TORCH_CHECK(!fusion.hasReduction(), "Trivial reduction picked up by fusion"); + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).empty(), + "Trivial reduction picked up by fusion"); const auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({10, 20, 1}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); auto aten_output = aten_input.to(at::kDouble).sum({2}); @@ -10373,7 +10479,7 @@ TEST(NVFuserTest, FusionTrivialReduction_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTrivialReduction2_CUDA) { +TEST_F(NVFuserTest, FusionTrivialReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10400,14 +10506,14 @@ TEST(NVFuserTest, FusionTrivialReduction2_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTrivialReduction3_CUDA) { +TEST_F(NVFuserTest, FusionTrivialReduction3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10433,7 +10539,7 @@ TEST(NVFuserTest, FusionTrivialReduction3_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( @@ -10442,7 +10548,7 @@ TEST(NVFuserTest, FusionTrivialReduction3_CUDA) { // Make sure trivial reductions are correctly detected even with // scheduling applied. -TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { +TEST_F(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10459,8 +10565,8 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { auto tv4 = tv2->rFactor({-1}); auto tv5 = broadcast(tv0, {true, false}); - auto tv6 = add(tv5, new Double(1)); - auto tv7 = sub(tv6, new Double(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); + auto tv7 = sub(tv6, IrBuilder::create(1)); auto tv8 = sum(tv7, {0}); fusion.addOutput(tv8); @@ -10483,10 +10589,10 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { GpuLower gpulw(&fusion); - // No kir::ReductionOp should be generated as all the reduction + // No ReductionOp should be generated as all the reduction // exprs should be replaced with a unary set op. - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - TORCH_CHECK(!kir_node->isA()); + for (const auto expr : gpulw.kernel()->as()->exprs()) { + TORCH_CHECK(!expr->isA()); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -10494,7 +10600,7 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -10502,14 +10608,14 @@ TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { } // Test detection of partially trivial reduction -TEST(NVFuserTest, FusionDetectTrivialReduction2_CUDA) { +TEST_F(NVFuserTest, FusionDetectTrivialReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(1, 1); @@ -10525,17 +10631,17 @@ TEST(NVFuserTest, FusionDetectTrivialReduction2_CUDA) { GpuLower gpulw(&fusion); // tv3's reduction axis is a trivial reduction. The only - // kir::ReductionOp should be for tv1. - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (kir_node->isA()) { + // ReductionOp should be for tv1. + for (const auto expr : gpulw.kernel()->as()->exprs()) { + if (expr->isA()) { auto reduction_out = - kir_node->as()->outputs()[0]->as(); - TORCH_CHECK(reduction_out->fuserTv() == tv1); + expr->as()->outputs()[0]->as(); + TORCH_CHECK(reduction_out->name() == 1); } } } -TEST(NVFuserTest, FusionInputsIdLookup_CUDA) { +TEST_F(NVFuserTest, FusionInputsIdLookup_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({16, 8, 8}, options); at::Tensor t1 = at::randn({8, 8}, options); @@ -10573,7 +10679,7 @@ TEST(NVFuserTest, FusionInputsIdLookup_CUDA) { TORCH_CHECK(id_1_relook.eviction == false); } -TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) { +TEST_F(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) { std::vector sizes_vec({16, 8, 8}); std::vector strides_vec({64, 8, 1}); auto tensor_type = TensorType::create( @@ -10610,7 +10716,7 @@ TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) { TORCH_CHECK(complyWith(t6, TensorType::create(t6))); } -TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) { +TEST_F(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) { std::vector sizes_vec({16, 1, 8}); std::vector strides_vec({8, 8, 1}); auto tensor_type = TensorType::create( @@ -10634,7 +10740,7 @@ TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) { TORCH_CHECK(complyWith(t3, tensor_type)); } -TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) { +TEST_F(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) { std::vector sizes_vec({16, 8, 8}); std::vector strides_vec({64, 1, 8}); auto tensor_type = TensorType::create( @@ -10650,7 +10756,7 @@ TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) { TORCH_CHECK(complyWith(t1, tensor_type)); } -TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) { +TEST_F(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) { std::vector sizes_vec({16, 8, 8}); std::vector strides_vec({128, 16, 1}); auto tensor_type = TensorType::create( @@ -10666,7 +10772,7 @@ TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) { TORCH_CHECK(complyWith(t1, tensor_type)); } -TEST(NVFuserTest, FusionDisjointSet_CUDA) { +TEST_F(NVFuserTest, FusionDisjointSet_CUDA) { DisjointSet set; const std::set group_x({0, 1, 2}); @@ -10779,7 +10885,7 @@ TEST(NVFuserTest, FusionDisjointSet_CUDA) { } } -TEST(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { +TEST_F(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10802,7 +10908,7 @@ TEST(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { ASSERT_ANY_THROW(tv3->computeAt(tv4, -1)); } -TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) { +TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10819,14 +10925,14 @@ TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto t3 = castOp(DataType::Float, t2); auto t4 = broadcast(t1, {true, true, false}); auto t5 = add(t4, t3); - auto t6 = mul(t5, new Double(0.5)); - auto t7 = mul(t5, new Double(k_079)); - auto t8 = mul(t5, new Double(k_004)); + auto t6 = mul(t5, IrBuilder::create(0.5)); + auto t7 = mul(t5, IrBuilder::create(k_079)); + auto t8 = mul(t5, IrBuilder::create(k_004)); auto t9 = mul(t8, t5); - auto t10 = add(t9, new Int(1)); + auto t10 = add(t9, IrBuilder::create(1)); auto t11 = mul(t7, t10); auto t12 = unaryOp(UnaryOpType::Tanh, t11); - auto t13 = add(t12, new Double(1)); + auto t13 = add(t12, IrBuilder::create(1)); auto t14 = mul(t6, t13); auto t15 = castOp(DataType::Half, t14); fusion.addOutput(t15); @@ -10849,15 +10955,14 @@ TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { +TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { if (at::cuda::getDeviceProperties(0)->major < 6) { return; } @@ -10882,23 +10987,23 @@ TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { auto t5 = castOp(DataType::Float, t4); auto t6 = broadcast(t3, {true, true, false}); auto t7 = add(t6, t5); - auto t8 = mul(t7, new Double(k_079)); - auto t9 = mul(t7, new Double(k_004)); + auto t8 = mul(t7, IrBuilder::create(k_079)); + auto t9 = mul(t7, IrBuilder::create(k_004)); auto t10 = mul(t9, t7); - auto t11 = add(t10, new Int(1)); + auto t11 = add(t10, IrBuilder::create(1)); auto t12 = mul(t8, t11); auto t13 = unaryOp(UnaryOpType::Tanh, t12); - auto t14 = mul(t7, new Double(0.5)); + auto t14 = mul(t7, IrBuilder::create(0.5)); auto t15 = mul(t13, t13); auto t16 = unaryOp(UnaryOpType::Neg, t15); - auto t17 = add(t16, new Int(1)); - auto t18 = mul(t7, new Double(k_010)); + auto t17 = add(t16, IrBuilder::create(1)); + auto t18 = mul(t7, IrBuilder::create(k_010)); auto t19 = mul(t18, t7); - auto t20 = add(t19, new Double(k_079)); + auto t20 = add(t19, IrBuilder::create(k_079)); auto t21 = mul(t17, t20); auto t22 = mul(t14, t21); - auto t23 = add(t13, new Int(1)); - auto t24 = mul(t23, new Double(0.5)); + auto t23 = add(t13, IrBuilder::create(1)); + auto t24 = mul(t23, IrBuilder::create(0.5)); auto t25 = add(t22, t24); auto t26 = mul(t25, t1); // Save float output for validation @@ -10929,8 +11034,7 @@ TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs, lparams); auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( @@ -10938,7 +11042,7 @@ TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) { } // Reproducer of issue #459 -TEST(NVFuserTest, FusionIssue459_CUDA) { +TEST_F(NVFuserTest, FusionIssue459_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -10947,14 +11051,14 @@ TEST(NVFuserTest, FusionIssue459_CUDA) { auto tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = add(tv1, tv3); // Create two outputs from the final arithmetic result - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); - auto tv6 = add(tv4, new Double(1)); + auto tv6 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv6); // Scheduling @@ -10981,8 +11085,7 @@ TEST(NVFuserTest, FusionIssue459_CUDA) { std::vector aten_inputs = {t0, t1}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -10994,15 +11097,15 @@ TEST(NVFuserTest, FusionIssue459_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionSmemIndexingSimple_CUDA) { +TEST_F(NVFuserTest, FusionSmemIndexingSimple_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); tv3->axis(0)->parallelize(ParallelType::BIDx); @@ -11013,28 +11116,27 @@ TEST(NVFuserTest, FusionSmemIndexingSimple_CUDA) { tv1->setMemoryType(MemoryType::Shared); tv2->setMemoryType(MemoryType::Global); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto aten_input = at::randn({12, 34}, options); at::Tensor aten_output = aten_input + 1.0 + 1.0 + 1.0; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSmemIndexing_CUDA) { +TEST_F(NVFuserTest, FusionSmemIndexing_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = new Int(); - Int* symbolic_split_k_tile_dim = new Int(); - Int* symbolic_block_k_tile_dim = new Int(); + Int* symbolic_m_tile_dim = IrBuilder::create(); + Int* symbolic_split_k_tile_dim = IrBuilder::create(); + Int* symbolic_block_k_tile_dim = IrBuilder::create(); // Compile-time integer for tiling int n_smem_tile = 32; @@ -11131,9 +11233,8 @@ TEST(NVFuserTest, FusionSmemIndexing_CUDA) { // A, B, m_tile_dim, split_k, intra_cta_tile std::vector aten_inputs = {t0, t1, 3, 4, 5}; - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( @@ -11141,13 +11242,13 @@ TEST(NVFuserTest, FusionSmemIndexing_CUDA) { } // Reproducer of issue 408 -TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) { +TEST_F(NVFuserTest, FusionCacheBeforeReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); fusion.addOutput(tv2); @@ -11160,9 +11261,6 @@ TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 100; const int numel_y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -11172,21 +11270,23 @@ TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) { auto aten_output = (aten_input + 1).to(at::kDouble).sum({1}); + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); fe.runFusion({aten_input}, {cg_output}); testValidate( &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { +TEST_F(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(3); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv2); fusion.addOutput(tv3); @@ -11201,9 +11301,6 @@ TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 10; const int numel_y = 20; const int numel_z = 30; @@ -11214,20 +11311,22 @@ TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { auto t3 = t2 + 1; std::vector aten_outputs = {t2, t3}; + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue367_CUDA) { +TEST_F(NVFuserTest, FusionIssue367_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = new Int(); - Int* symbolic_split_k_tile_dim = new Int(); - Int* symbolic_block_k_tile_dim = new Int(); + Int* symbolic_m_tile_dim = IrBuilder::create(); + Int* symbolic_split_k_tile_dim = IrBuilder::create(); + Int* symbolic_block_k_tile_dim = IrBuilder::create(); // Compile-time integer for tiling int n_smem_tile = 32; @@ -11320,14 +11419,14 @@ TEST(NVFuserTest, FusionIssue367_CUDA) { mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue468_CUDA) { +TEST_F(NVFuserTest, FusionIssue468_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -11346,15 +11445,15 @@ TEST(NVFuserTest, FusionIssue468_CUDA) { at::Tensor aten_input = at::randn({10, 100}, options); at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}).sum({0}); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue363_CUDA) { +TEST_F(NVFuserTest, FusionIssue363_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -11402,21 +11501,21 @@ TEST(NVFuserTest, FusionIssue363_CUDA) { std::vector aten_inputs = {t0, t1}; torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue484_CUDA) { +TEST_F(NVFuserTest, FusionIssue484_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, new Double(0)); + auto tv2 = add(tv1, IrBuilder::create(0)); fusion.addOutput(tv2); tv1->setMemoryType(MemoryType::Global); @@ -11430,20 +11529,20 @@ TEST(NVFuserTest, FusionIssue484_CUDA) { at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}); torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue329_CUDA) { +TEST_F(NVFuserTest, FusionIssue329_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); fusion.addOutput(tv2); auto tv3 = sum(tv1, {1}); @@ -11460,22 +11559,21 @@ TEST(NVFuserTest, FusionIssue329_CUDA) { std::vector aten_outputs = {t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue382_CUDA) { +TEST_F(NVFuserTest, FusionIssue382_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = broadcast(tv1, {false, false, true}); auto tv3 = makeSymbolicTensor(3); fusion.addInput(tv3); @@ -11492,9 +11590,6 @@ TEST(NVFuserTest, FusionIssue382_CUDA) { tv1->setMemoryType(MemoryType::Global); tv2->setMemoryType(MemoryType::Global); - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 12; const int numel_y = 34; const int numel_z = 56; @@ -11507,20 +11602,22 @@ TEST(NVFuserTest, FusionIssue382_CUDA) { std::vector aten_inputs = {t0, t3}; auto aten_output = (t0 + 1).unsqueeze(-1) + t3; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); testValidate( &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue507_CUDA) { +TEST_F(NVFuserTest, FusionIssue507_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->setMemoryType(MemoryType::Shared); @@ -11538,22 +11635,21 @@ TEST(NVFuserTest, FusionIssue507_CUDA) { auto aten_output = (t1 + 1); FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); testValidate( &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue532_CUDA) { +TEST_F(NVFuserTest, FusionIssue532_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Algorithm TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(1)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(1)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -11579,7 +11675,7 @@ TEST(NVFuserTest, FusionIssue532_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0 + 1 + 1; @@ -11588,14 +11684,14 @@ TEST(NVFuserTest, FusionIssue532_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionLoopUnswitch_CUDA) { +TEST_F(NVFuserTest, FusionLoopUnswitch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // Algorithm TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, new Double(1)); - TensorView* tv2 = add(tv1, new Double(1)); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(1)); fusion.addInput(tv0); fusion.addOutput(tv2); @@ -11612,7 +11708,7 @@ TEST(NVFuserTest, FusionLoopUnswitch_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0 + 1 + 1; @@ -11621,7 +11717,7 @@ TEST(NVFuserTest, FusionLoopUnswitch_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue549_CUDA) { +TEST_F(NVFuserTest, FusionIssue549_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -11631,7 +11727,7 @@ TEST(NVFuserTest, FusionIssue549_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); TensorView* tv3 = broadcast(tv2, {false, false, true}); // tv3[I0, I1, B] = tv0[I0, I1] @@ -11689,10 +11785,12 @@ TEST(NVFuserTest, FusionIssue549_CUDA) { at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); - FusionExecutor fe; - fe.compileFusion(&fusion); // Lets specify a few bounds in launch params to make sure it works - fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); + LaunchParams lparams(1, -1, -1, 32, 4, 4); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + fe.runFusion({t0, t1}, lparams); // Make sure bad launch params throws // TODO: Re-enable once we have parallelization validation in. @@ -11707,7 +11805,7 @@ TEST(NVFuserTest, FusionIssue549_CUDA) { &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, simplecompileRtc_CUDA) { +TEST_F(NVFuserTest, FusionSimpleCompileRtc_CUDA) { FusionExecutor fe; std::string kernel = R"( __global__ void kernel1(Tensor T0, Tensor T1) { @@ -11739,7 +11837,7 @@ __global__ void kernel1(Tensor T0, Tensor T1) { TORCH_CHECK(out_ref.allclose(out0)); } -TEST(NVFuserTest, FusionSerialWelford_CUDA) { +TEST_F(NVFuserTest, FusionSerialWelford_CUDA) { FusionExecutor fe; int x = 128, y = 64, z = 64; @@ -11796,7 +11894,7 @@ __global__ void kernel1( TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); } -TEST(NVFuserTest, FusionBlockWelford_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelford_CUDA) { FusionExecutor fe; int x = 7, y = 8, z = 9; @@ -11884,7 +11982,7 @@ __global__ void kernel1( cat_tensor.mean({1}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); } -TEST(NVFuserTest, FusionBlockWelfordNoInit_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelfordNoInit_CUDA) { FusionExecutor fe; int x = 7, y = 8, z = 9; @@ -11950,7 +12048,7 @@ __global__ void kernel1( TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); } -TEST(NVFuserTest, FusionGridWelfordNoInit_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordNoInit_CUDA) { FusionExecutor fe; int x = 128, y = 64, z = 128; @@ -12040,7 +12138,7 @@ __global__ void kernel1( TORCH_CHECK(in0.var(dims, false).allclose(out_var)); } -TEST(NVFuserTest, FusionWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12048,7 +12146,7 @@ TEST(NVFuserTest, FusionWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12069,7 +12167,7 @@ TEST(NVFuserTest, FusionWelfordOp_CUDA) { at::Tensor t0 = at::randn({M, N}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12084,7 +12182,7 @@ TEST(NVFuserTest, FusionWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12092,7 +12190,7 @@ TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12115,7 +12213,7 @@ TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { at::Tensor t_N = at::empty({M}, options_int); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12130,7 +12228,7 @@ TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12138,7 +12236,7 @@ TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12161,7 +12259,7 @@ TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { at::Tensor t_N = at::empty({M}, options_int); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12176,7 +12274,7 @@ TEST(NVFuserTest, FusionGridWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { +TEST_F(NVFuserTest, FusionRfactorWelfordOp_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12184,7 +12282,7 @@ TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12206,7 +12304,7 @@ TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { at::Tensor t_N = at::empty({M}, options_int); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // by default Welford outputs sum of square diff so need to divide to get var @@ -12221,7 +12319,7 @@ TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionWelfordSchedule_CUDA) { +TEST_F(NVFuserTest, FusionWelfordSchedule_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12229,7 +12327,7 @@ TEST(NVFuserTest, FusionWelfordSchedule_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = mul(tv0, new Double(1)); + auto tv1 = mul(tv0, IrBuilder::create(1)); auto tvs = Welford(tv1, {1}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12246,9 +12344,10 @@ TEST(NVFuserTest, FusionWelfordSchedule_CUDA) { auto reduction_params = getReductionHeuristics(&fusion, {t0}); scheduleReduction(&fusion, reduction_params.value()); + auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({t0}, reduction_params.value().lparams); + fe.compileFusion(&fusion, {t0}, lparams); + auto outputs = fe.runFusion({t0}, lparams); // by default Welford outputs sum of square diff so need to divide to get var outputs[1] /= N; @@ -12283,7 +12382,7 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) { tv0_cast = castOp(DataType::Float, tv0); } fusion.addInput(tv0); - auto tv1 = mul(tv0_cast, new Double(1)); + auto tv1 = mul(tv0_cast, IrBuilder::create(1)); auto tvs = Welford(tv1, {axis}); auto tv_avg = tvs.avg; auto tv_M2 = tvs.var_sum; @@ -12324,8 +12423,8 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); - auto outputs = fe.runFusion({aten_input}, reduction_params.value().lparams); + fe.compileFusion(&fusion, {aten_input}, lparams); + auto outputs = fe.runFusion({aten_input}, lparams); // by default Welford outputs sum of square diff so need to divide to // get var @@ -12351,7 +12450,7 @@ void testWelford(DataType dtype, int red_axis, int odim, int rdim) { } } // namespace -TEST(NVFuserTest, FusionWelfordShmoo_CUDA) { +TEST_F(NVFuserTest, FusionWelfordShmoo_CUDA) { std::vector dtypes = { DataType::Double, DataType::Float, DataType::Half}; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 @@ -12393,7 +12492,7 @@ TEST(NVFuserTest, FusionWelfordShmoo_CUDA) { } } -TEST(NVFuserTest, FusionTranspose1_CUDA) { +TEST_F(NVFuserTest, FusionTranspose1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12414,7 +12513,7 @@ TEST(NVFuserTest, FusionTranspose1_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0.t(); @@ -12423,7 +12522,7 @@ TEST(NVFuserTest, FusionTranspose1_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTranspose2_CUDA) { +TEST_F(NVFuserTest, FusionTranspose2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12447,7 +12546,7 @@ TEST(NVFuserTest, FusionTranspose2_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0.t(); @@ -12456,7 +12555,7 @@ TEST(NVFuserTest, FusionTranspose2_CUDA) { &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { +TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12526,10 +12625,11 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { at::Tensor t0 = at::randn({K, M}, options); at::Tensor t1 = at::randn({N, K}, options); - FusionExecutor fe; - fe.compileFusion(&fusion); // Lets specify a few bounds in launch params to make sure it works - fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); + LaunchParams lparams(1, -1, -1, 32, 4, 4); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + fe.runFusion({t0, t1}, lparams); // Don't specify any launch params auto cg_outputs = fe.runFusion({t0, t1}); @@ -12540,7 +12640,7 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { +TEST_F(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12593,7 +12693,7 @@ TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_input_t = at::transpose(input, 1, 2); @@ -12603,7 +12703,7 @@ TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { // Case 1 // tv1 = tv0 * 0.5 // tv2 = tv1 * -1 @@ -12620,10 +12720,10 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { tv0 = transpose(tv0, {{0, 1}}); - TensorView* tv1 = mul(tv0, new Double(0.5)); - TensorView* tv2 = mul(tv1, new Double(-1.0)); - TensorView* tv3 = add(tv1, new Double(3.0)); - TensorView* tv4 = mul(tv1, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); TensorView* tv5 = add(tv3, tv2); TensorView* tv6 = add(tv5, tv4); @@ -12654,7 +12754,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { } for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(1)->parallelize(ParallelType::Unroll); @@ -12667,7 +12767,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { at::Tensor aten_input = at::randn({129, 127}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); at::Tensor aten_input_t = aten_input.t(); @@ -12686,7 +12786,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { // Case 2 // tv1 = tv0 * -1 // tv2 = tv0 + 3 @@ -12702,9 +12802,9 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { tv0 = transpose(tv0, {{0, 1}}); - TensorView* tv1 = mul(tv0, new Double(-1.0)); - TensorView* tv2 = add(tv0, new Double(3.0)); - TensorView* tv3 = mul(tv0, new Double(2.0)); + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); @@ -12723,7 +12823,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { tv0->computeAt(tv6, 1); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -12736,7 +12836,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { at::Tensor input = at::randn({129, 127}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto cg_outputs = fe.runFusion({input}); auto input_t = input.t(); @@ -12752,7 +12852,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { // Case 3 // T2 = T1 * 0.979361 // T3 = T2 * T0 @@ -12769,7 +12869,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}}); - TensorView* tv2 = mul(tv1, new Double(.979361)); + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); TensorView* tv3 = mul(tv2, tv0); fusion.addOutput(tv3); @@ -12786,7 +12886,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { tv3->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -12802,7 +12902,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t0_t = t0.permute({3, 0, 1, 2}); @@ -12814,7 +12914,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { // Case 4 // T4 = T2 - T3 // T5 = T1 + T4 @@ -12862,7 +12962,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { tv6->axis(0)->parallelize(ParallelType::BIDx); for (Val* val : fusion.vals()) { - if (!fusion.hasInput(val) && + if (!val->isFusionInput() && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); @@ -12880,7 +12980,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { std::vector aten_inputs = {t0, t1, t2, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t0_t = t0.permute({3, 0, 1, 2}); @@ -12895,7 +12995,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { // Case 5 // tv2 = tv0 + 2.0 // tv3 = tv1 * tv2 @@ -12909,7 +13009,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); tv1 = transpose(tv1, {{0, 1}}); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -12928,7 +13028,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t2 = t0.t().add(2.0); @@ -12938,7 +13038,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -12948,7 +13048,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { TensorView* tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); tv1 = transpose(tv1, {{0, 1}}); - TensorView* tv2 = add(tv0, new Double(2.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); TensorView* tv3 = mul(tv1, tv2); fusion.addOutput(tv3); @@ -12970,7 +13070,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t2 = t0.t().add(2.0); @@ -12980,7 +13080,7 @@ TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) { +TEST_F(NVFuserTest, FusionSegmentReducePointwise_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -12992,7 +13092,7 @@ TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) { fusion->addInput(tv1); fusion->addInput(tv2); - TensorView* tv3 = add(tv0, new Double(1)); // Group 0 + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 TensorView* tv4 = max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, @@ -13029,7 +13129,7 @@ TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) { executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultipleVectorize_CUDA) { +TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -13089,7 +13189,7 @@ TEST(NVFuserTest, FusionMultipleVectorize_CUDA) { TORCH_CHECK(runtime1 != runtime3); } -TEST(NVFuserTest, FusionVectorizeSimple_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeSimple_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13123,7 +13223,7 @@ TEST(NVFuserTest, FusionVectorizeSimple_CUDA) { at::Tensor aten_input = at::empty({2, 6, 32}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {aten_input}); auto cg_outputs = fe.runFusion({aten_input}); at::Tensor aten_output = aten_input.sin(); @@ -13132,7 +13232,7 @@ TEST(NVFuserTest, FusionVectorizeSimple_CUDA) { &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { +TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { Fusion fusion; FusionGuard fg(&fusion); // dimensionality of the problem @@ -13148,7 +13248,7 @@ TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { // Do math with it, it returns a `Val*` but can be static_casted back to // TensorView - TensorView* tv2 = add(tv1, new Double(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); TensorView* tv3 = add(tv0, tv2); // Register your outputs @@ -13197,7 +13297,7 @@ TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { at::Tensor output = at::empty_like(input1); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); fe.runFusion({input1, input2}, {output}); at::Tensor tv2_ref = input2 + 2.0; @@ -13206,7 +13306,7 @@ TEST(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { TORCH_CHECK(output_ref.equal(output)); } -TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { +TEST_F(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -13220,7 +13320,7 @@ TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { fusion->addInput(tv0); - auto tv1 = add(tv0, new Double(1.0)); + auto tv1 = add(tv0, IrBuilder::create(1.0)); auto tv2 = sum(tv1, {2}); // Group 0 auto output = softmax(tv2, kReductionAxis); // Group 1 @@ -13247,14 +13347,14 @@ TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSwizzle1_CUDA) { +TEST_F(NVFuserTest, FusionSwizzle1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = mul(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = mul(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); tv2->split(0, 7); @@ -13279,7 +13379,7 @@ TEST(NVFuserTest, FusionSwizzle1_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = (t0 + 1) * 2; @@ -13288,14 +13388,14 @@ TEST(NVFuserTest, FusionSwizzle1_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSwizzle2_CUDA) { +TEST_F(NVFuserTest, FusionSwizzle2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = mul(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = mul(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -13323,7 +13423,7 @@ TEST(NVFuserTest, FusionSwizzle2_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = (t0 + 1) * 2; @@ -13332,7 +13432,7 @@ TEST(NVFuserTest, FusionSwizzle2_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { +TEST_F(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13385,8 +13485,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.t(); @@ -13395,7 +13494,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { +TEST_F(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13452,8 +13551,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); - + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.t(); @@ -13462,10 +13560,7 @@ TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridPersistence_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridPersistence_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13490,7 +13585,7 @@ TEST(NVFuserTest, FusionGridPersistence_CUDA) { at::Tensor input = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = input.sum({0}).unsqueeze(-1).add(input); @@ -13498,10 +13593,7 @@ TEST(NVFuserTest, FusionGridPersistence_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridPersistence2_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionGridPersistence2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13528,7 +13620,7 @@ TEST(NVFuserTest, FusionGridPersistence2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = input.sum({0}).unsqueeze(0).add(input); @@ -13536,10 +13628,7 @@ TEST(NVFuserTest, FusionGridPersistence2_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWelfordPersistence_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionWelfordPersistence_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13567,7 +13656,7 @@ TEST(NVFuserTest, FusionWelfordPersistence_CUDA) { at::Tensor input = at::randn({numel_x}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) @@ -13577,10 +13666,7 @@ TEST(NVFuserTest, FusionWelfordPersistence_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWelfordPersistence2_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionWelfordPersistence2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13610,7 +13696,7 @@ TEST(NVFuserTest, FusionWelfordPersistence2_CUDA) { at::Tensor input = at::randn({numel_x, numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); auto out = fe.runFusion({input}); auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) @@ -13620,7 +13706,7 @@ TEST(NVFuserTest, FusionWelfordPersistence2_CUDA) { testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue633_CUDA) { +TEST_F(NVFuserTest, FusionIssue633_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13642,14 +13728,13 @@ TEST(NVFuserTest, FusionIssue633_CUDA) { tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dx, dy, dz}, options); at::Tensor t1 = at::randn({dx, dy, 1}, options); std::vector aten_inputs = {t0, t1}; + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13658,48 +13743,7 @@ TEST(NVFuserTest, FusionIssue633_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionKirScoping_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); - fusion.addOutput(tv2); - - tv2->merge(0); - tv2->split(0, 4); - tv0->computeAt(tv2, -1); - - GpuLower gpulw(&fusion); - - auto kir_tv1 = gpulw.lowerValue(tv1); - auto tv1_scope = kir_tv1->definition()->scope(); - TORCH_CHECK(tv1_scope != nullptr); - TORCH_CHECK(tv1_scope->owner()->as()); - - auto kir_tv2 = gpulw.lowerValue(tv2); - auto tv2_scope = kir_tv2->definition()->scope(); - TORCH_CHECK(tv2_scope != nullptr); - TORCH_CHECK(tv2_scope->owner()->as()); - - TORCH_CHECK(tv1_scope != tv2_scope); - - // tv1 and tv2 should have the same inner-most ForLoop - auto parent_scope = tv1_scope->owner()->scope(); - TORCH_CHECK(parent_scope == tv2_scope->owner()->scope()); - TORCH_CHECK(parent_scope->owner()->as()); - // There should be one more loop - parent_scope = parent_scope->owner()->scope(); - TORCH_CHECK(parent_scope->owner()->as()); - - // scope() should return nullptr for top-level exprs - auto top_level_scope = parent_scope->owner()->scope(); - TORCH_CHECK(top_level_scope == nullptr); -} - -TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { +TEST_F(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13726,7 +13770,7 @@ TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t3 = t0.unsqueeze(-1).expand(shape) + t1; @@ -13734,7 +13778,7 @@ TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13777,7 +13821,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13785,7 +13829,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13835,7 +13879,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13843,7 +13887,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13896,7 +13940,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -13904,7 +13948,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -13963,7 +14007,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14013,7 +14057,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.add(t1).sum(1); @@ -14021,7 +14065,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14059,7 +14103,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14100,7 +14144,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -14108,7 +14152,7 @@ TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { +TEST_F(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14151,13 +14195,13 @@ TEST(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); // Failure because the input + output tensors do not have the same stride ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); } -TEST(NVFuserTest, FusionViewOutput_CUDA) { +TEST_F(NVFuserTest, FusionViewOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14181,7 +14225,7 @@ TEST(NVFuserTest, FusionViewOutput_CUDA) { auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_add_bias = at_x + at_bias; @@ -14190,7 +14234,7 @@ TEST(NVFuserTest, FusionViewOutput_CUDA) { testValidate(&fusion, outputs, aten_inputs, {at_x_view}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionViewFailMismatchSize_CUDA) { +TEST_F(NVFuserTest, FusionViewFailMismatchSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14210,7 +14254,7 @@ TEST(NVFuserTest, FusionViewFailMismatchSize_CUDA) { ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape)); } -TEST(NVFuserTest, FusionViewFailMulitDimInference_CUDA) { +TEST_F(NVFuserTest, FusionViewFailMulitDimInference_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14228,7 +14272,7 @@ TEST(NVFuserTest, FusionViewFailMulitDimInference_CUDA) { ASSERT_ANY_THROW(view(x_add_bias, input_shape, output_shape)); } -TEST(NVFuserTest, FusionViewFailReduction_CUDA) { +TEST_F(NVFuserTest, FusionViewFailReduction_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -14259,7 +14303,7 @@ TEST(NVFuserTest, FusionViewFailReduction_CUDA) { ASSERT_ANY_THROW(fusion_executor_cache.runFusionWithInputs({at_x, at_bias})); } -TEST(NVFuserTest, FusionViewFailPersistent_CUDA) { +TEST_F(NVFuserTest, FusionViewFailPersistent_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -14319,7 +14363,7 @@ void addViewGeluFusion( auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_add_bias = at_x + at_bias; @@ -14330,25 +14374,25 @@ void addViewGeluFusion( } } -TEST(NVFuserTest, FusionViewSplit_CUDA) { +TEST_F(NVFuserTest, FusionViewSplit_CUDA) { std::vector input_shape{80}; std::vector output_shape{2, 4, 10}; addViewGeluFusion(input_shape, output_shape); } -TEST(NVFuserTest, FusionViewBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionViewBroadcast_CUDA) { std::vector input_shape{80}; std::vector output_shape{1, 80}; addViewGeluFusion(input_shape, output_shape); } -TEST(NVFuserTest, FusionViewMerge_CUDA) { +TEST_F(NVFuserTest, FusionViewMerge_CUDA) { std::vector input_shape{2, 40, 7}; std::vector output_shape{560}; addViewGeluFusion(input_shape, output_shape); } -TEST(NVFuserTest, FusionViewAllShmoo_CUDA) { +TEST_F(NVFuserTest, FusionViewAllShmoo_CUDA) { typedef std::vector shape; typedef std::pair view_example; @@ -14373,7 +14417,7 @@ TEST(NVFuserTest, FusionViewAllShmoo_CUDA) { } } -TEST(NVFuserTest, FusionViewInferShmoo_CUDA) { +TEST_F(NVFuserTest, FusionViewInferShmoo_CUDA) { typedef std::vector shape; typedef std::pair view_example; @@ -14427,7 +14471,7 @@ void geluViewAddFusion( auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_gelu = at::gelu(at_x); @@ -14438,7 +14482,7 @@ void geluViewAddFusion( } } -TEST(NVFuserTest, FusionViewStride_CUDA) { +TEST_F(NVFuserTest, FusionViewStride_CUDA) { typedef std::vector shape; typedef std::pair view_example; @@ -14483,7 +14527,7 @@ void geluViewBinaryAddFusion( auto lparams = schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs, lparams); auto outputs = fe.runFusion(aten_inputs, lparams); auto at_x_gelu = at::gelu(at_x); @@ -14495,11 +14539,11 @@ void geluViewBinaryAddFusion( } } -TEST(NVFuserTest, FusionViewBinary_CUDA) { +TEST_F(NVFuserTest, FusionViewBinary_CUDA) { geluViewBinaryAddFusion({27454, 2}, {54908}, {7844, 7}); } -TEST(NVFuserTest, FusionVectorization1_CUDA) { +TEST_F(NVFuserTest, FusionVectorization1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14540,7 +14584,7 @@ TEST(NVFuserTest, FusionVectorization1_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0 + t1; @@ -14548,7 +14592,7 @@ TEST(NVFuserTest, FusionVectorization1_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorization2_CUDA) { +TEST_F(NVFuserTest, FusionVectorization2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14586,7 +14630,7 @@ TEST(NVFuserTest, FusionVectorization2_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionVectorization3_CUDA) { +TEST_F(NVFuserTest, FusionVectorization3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14623,11 +14667,10 @@ TEST(NVFuserTest, FusionVectorization3_CUDA) { const int by = 2049; at::Tensor t0 = at::randn({bx, by}, options); at::Tensor t1 = at::randn({bx, by}, options); + std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); - - std::vector aten_inputs = {t0, t1}; + fe.compileFusion(&fusion, aten_inputs); ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); aten_inputs[0] = t0.index({"...", Slice(1)}); @@ -14644,7 +14687,7 @@ TEST(NVFuserTest, FusionVectorization3_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) { +TEST_F(NVFuserTest, FusionVectorizationRFactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14692,7 +14735,7 @@ TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) { std::vector aten_inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto aten_output = t0.add(t1).sum(1); @@ -14705,7 +14748,7 @@ TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) { } // Unswitched loops with extent one may omit else clause. -TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { +TEST_F(NVFuserTest, FusionSizeOneLoop1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14742,16 +14785,7 @@ TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { // Make sure the unswitched loop does not have an else clause. GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto fl = dynamic_cast(kir_node.get())) { - if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) { - continue; - } - if (auto pred = dynamic_cast(fl->parentScope())) { - TORCH_CHECK(!pred->hasElse()); - } - } - } + TORCH_CHECK(!UnswitchInElseChecker::check(gpulw)); const int x = 11; const int y = 12; @@ -14763,7 +14797,7 @@ TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { std::vector aten_inputs = {t0, t1, t2}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t6 = (t0.unsqueeze(-1) + t1).unsqueeze(0) + t2; @@ -14772,7 +14806,7 @@ TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) { // The unswitched loop has extent one but inner loops don't. The else // part should not be omitted. -TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) { +TEST_F(NVFuserTest, FusionSizeOneLoop2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14780,7 +14814,7 @@ TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) { auto tv0 = makeConcreteTensor({x}); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); tv1->split(-1, 4); @@ -14790,38 +14824,29 @@ TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) { // Make sure the size-one unswitched loop does not omit the else clause. GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto fl = dynamic_cast(kir_node.get())) { - if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) { - continue; - } - if (auto pred = dynamic_cast(fl->parentScope())) { - TORCH_CHECK(pred->hasElse()); - } - } - } + TORCH_CHECK(UnswitchInElseChecker::check(gpulw)); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x}, options); std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); auto t1 = t0 + 1; testValidate(&fusion, cg_outputs, aten_inputs, {t1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionValidateParallelize1_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -14832,15 +14857,15 @@ TEST(NVFuserTest, FusionValidateParallelize1_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionValidateParallelize2_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -14853,15 +14878,15 @@ TEST(NVFuserTest, FusionValidateParallelize2_CUDA) { fe.compileFusion(&fusion); } -TEST(NVFuserTest, FusionValidateParallelize3_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -14876,15 +14901,15 @@ TEST(NVFuserTest, FusionValidateParallelize3_CUDA) { fe.compileFusion(&fusion); } -TEST(NVFuserTest, FusionValidateParallelize4_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -14899,15 +14924,15 @@ TEST(NVFuserTest, FusionValidateParallelize4_CUDA) { ASSERT_ANY_THROW(fe.compileFusion(&fusion)); } -TEST(NVFuserTest, FusionValidateParallelize5_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv1->split(-1, 4); @@ -14924,7 +14949,7 @@ TEST(NVFuserTest, FusionValidateParallelize5_CUDA) { } // See issue #995 -TEST(NVFuserTest, FusionValidateParallelize6_CUDA) { +TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14933,7 +14958,7 @@ TEST(NVFuserTest, FusionValidateParallelize6_CUDA) { fusion.addInput(tv0); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -14960,7 +14985,7 @@ TEST(NVFuserTest, FusionValidateParallelize6_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -TEST(NVFuserTest, FusionDAGMerging_CUDA) { +TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14976,7 +15001,7 @@ TEST(NVFuserTest, FusionDAGMerging_CUDA) { auto tv5 = sum(tv4, {0}); // 3 // Branch 1 - auto tv6 = add(tv1, new Double(1)); // 4 + auto tv6 = add(tv1, IrBuilder::create(1)); // 4 // Merge auto tv7 = add(tv6, tv5); // 5 @@ -14995,17 +15020,17 @@ TEST(NVFuserTest, FusionDAGMerging_CUDA) { TORCH_CHECK(fusion_segments->groups().size() <= 4); } -TEST(NVFuserTest, FusionDAGScalarMerging_CUDA) { +TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(3); - auto i0 = new Double(); + auto i0 = IrBuilder::create(); fusion->addInput(tv0); fusion->addInput(i0); - auto i1 = add(i0, new Double(1.0)); + auto i1 = add(i0, IrBuilder::create(1.0)); auto i2 = mul(i1, i1); auto i3 = add(i2, i1); @@ -15051,7 +15076,7 @@ TEST(NVFuserTest, FusionDAGScalarMerging_CUDA) { executor_cache.fusion(), outputs, {t0, s0}, {t5}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { +TEST_F(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15073,14 +15098,14 @@ TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_output = t0.sum({1, 2}); testValidate( &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { +TEST_F(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15106,7 +15131,7 @@ TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); at::Tensor aten_avg = t0.mean({1, 2}); at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K; @@ -15115,7 +15140,7 @@ TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { } // See Issue #716 -TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { +TEST_F(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15129,7 +15154,7 @@ TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { std::vector broadcast_mask = {false, true}; auto tv0_bcast = broadcast(tv0, broadcast_mask); - auto path1_bcast = add(tv0_bcast, new Double(1.0)); + auto path1_bcast = add(tv0_bcast, IrBuilder::create(1.0)); auto path1 = sum(path1_bcast, reduction_axes); fusion.addOutput(path1); @@ -15145,7 +15170,7 @@ TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { std::vector aten_inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); // inplace op, we are adding t0 to itself auto outputs = fe.runFusion(aten_inputs, {t0}); @@ -15153,7 +15178,7 @@ TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { TORCH_CHECK(outputs[0].allclose(t0_ref.add(1))); } -TEST(NVFuserTest, FusionReductionPredicate_CUDA) { +TEST_F(NVFuserTest, FusionReductionPredicate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15184,7 +15209,7 @@ TEST(NVFuserTest, FusionReductionPredicate_CUDA) { at::Tensor cg_output = at::empty({numel_y}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}); fe.runFusion({input}, {cg_output}); auto aten_output = input.to(at::kDouble).sum({0}); @@ -15193,7 +15218,7 @@ TEST(NVFuserTest, FusionReductionPredicate_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue728_CUDA) { +TEST_F(NVFuserTest, FusionIssue728_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15204,10 +15229,10 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { auto tv2 = makeSymbolicTensor(1); fusion.addOutput(tv2); - auto tv3 = add(tv0, new Double(1)); + auto tv3 = add(tv0, IrBuilder::create(1)); auto tv4 = add(tv3, tv1); - auto tv5 = add(tv4, new Double(1)); - auto tv6 = add(tv2, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv5); fusion.addOutput(tv6); @@ -15253,7 +15278,7 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { "Only tv3 should be included"); } -TEST(NVFuserTest, FusionIssue757_CUDA) { +TEST_F(NVFuserTest, FusionIssue757_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15281,7 +15306,7 @@ TEST(NVFuserTest, FusionIssue757_CUDA) { std::vector inputs = {t0, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0.sum({1}); @@ -15292,7 +15317,7 @@ TEST(NVFuserTest, FusionIssue757_CUDA) { } // See issue #759 -TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15323,7 +15348,7 @@ TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { std::vector inputs = {t0, t3}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0.sum({1}); @@ -15333,7 +15358,7 @@ TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { +TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -15367,12 +15392,12 @@ TEST(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { TORCH_CHECK(segmented_fusion->groups().size() == 2); } -TEST(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { +TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(3); - auto i0 = new Double(); + auto i0 = IrBuilder::create(); fusion->addInput(tv0); fusion->addInput(i0); @@ -15407,7 +15432,7 @@ TEST(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { TORCH_CHECK(segmented_fusion->groups().size() == 2); } -TEST(NVFuserTest, FusionSegmentMixReduction_CUDA) { +TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -15446,7 +15471,7 @@ TEST(NVFuserTest, FusionSegmentMixReduction_CUDA) { TORCH_CHECK(segmented_fusion->groups().size() <= 2); } -TEST(NVFuserTest, FusionSBAR_CUDA) { +TEST_F(NVFuserTest, FusionSBAR_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15492,12 +15517,11 @@ TEST(NVFuserTest, FusionSBAR_CUDA) { // outputs std::vector outputs; - auto lparams = schedulePointwise(&fusion, c10::ArrayRef(inputs)); + auto lparams = schedulePointwise(&fusion, inputs); FusionExecutor executor; - executor.compileFusion(&fusion); - - outputs = executor.runFusion(c10::ArrayRef(inputs), lparams); + executor.compileFusion(&fusion, inputs, lparams); + outputs = executor.runFusion(inputs, lparams); auto at_scale = at::mul(at_x, at_weight); auto at_scale_bias = at::add(at_scale, at_bias); @@ -15507,16 +15531,16 @@ TEST(NVFuserTest, FusionSBAR_CUDA) { testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSingleElement_CUDA) { +TEST_F(NVFuserTest, FusionSingleElement_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(0); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(2.5)); + auto tv1 = add(tv0, IrBuilder::create(2.5)); - auto tv2 = add(tv1, new Double(3.5)); + auto tv2 = add(tv1, IrBuilder::create(3.5)); fusion.addOutput(tv2); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -15527,7 +15551,7 @@ TEST(NVFuserTest, FusionSingleElement_CUDA) { auto lparams = schedulePointwise(&fusion, {input}); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input}, lparams); fe.runFusion({input}, {cg_output}, lparams); auto aten_output = input.add(2.5).add(3.5); @@ -15536,7 +15560,7 @@ TEST(NVFuserTest, FusionSingleElement_CUDA) { &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) { +TEST_F(NVFuserTest, FusionBNBackwardRepro_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15566,12 +15590,12 @@ TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) { makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous. fusion.addInput(gt_0); - auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1)); + auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); auto gt_float = castOp(DataType::Float, gt_bool); auto grad_out = mul(grad_out_prev, gt_float); - Val* eps_ptr = new Double(1e-5); + Val* eps_ptr = IrBuilder::create(1e-5); auto grads = batch_norm_backward( input, @@ -15606,7 +15630,7 @@ TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) { } // TODO: We only changed inputs, merge this with the test above. -TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) { +TEST_F(NVFuserTest, FusionBNBackwardRepro2_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15639,12 +15663,12 @@ TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) { auto gt_0 = makeConcreteTensor({-1, -1, 1, 1}); fusion.addInput(gt_0); - auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1)); + auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); auto gt_float = castOp(DataType::Float, gt_bool); auto grad_out = mul(grad_out_prev, gt_float); - Val* eps_ptr = new Double(1e-5); + Val* eps_ptr = IrBuilder::create(1e-5); auto grads = batch_norm_backward( input, @@ -15678,7 +15702,7 @@ TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) { auto outputs = fec.runFusionWithInputs(inputs); } -TEST(NVFuserTest, FusionBNRepro_CUDA) { +TEST_F(NVFuserTest, FusionBNRepro_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15704,8 +15728,8 @@ TEST(NVFuserTest, FusionBNRepro_CUDA) { auto running_var = makeSymbolicTensor(1); fusion.addInput(running_var); - auto momentum_ptr = new Double(kMomentum); - auto eps_ptr = new Double(kEps); + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); auto result = batch_norm( input, @@ -15759,7 +15783,7 @@ TEST(NVFuserTest, FusionBNRepro_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBNRepro2_CUDA) { +TEST_F(NVFuserTest, FusionBNRepro2_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -15777,8 +15801,8 @@ TEST(NVFuserTest, FusionBNRepro2_CUDA) { auto input = makeSymbolicTensor(numDims); fusion.addInput(input); - Val* momentum_ptr = new Double(kMomentum); - Val* eps_ptr = new Double(kEps); + Val* momentum_ptr = IrBuilder::create(kMomentum); + Val* eps_ptr = IrBuilder::create(kEps); auto result = batch_norm( input, @@ -15820,7 +15844,7 @@ TEST(NVFuserTest, FusionBNRepro2_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { +TEST_F(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15830,7 +15854,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { auto tv1 = makeConcreteTensor({0}); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(2.5)); + auto tv2 = add(tv0, IrBuilder::create(2.5)); fusion.addOutput(tv2); auto tv3 = makeConcreteTensor({0}); @@ -15846,7 +15870,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { auto lparams = schedulePointwise(&fusion, {input0, input1}); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}); fe.runFusion({input0, input1}, {cg_output2, cg_output3}, lparams); auto aten_output2 = input0.add(2.5); @@ -15861,7 +15885,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { +TEST_F(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15891,7 +15915,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}, lparams); auto cg_outputs = fe.runFusion({input0, input1}, lparams); auto aten_output2 = input0.sum({1}); at::Tensor aten_output3 = at::empty({0}, options); @@ -15907,7 +15931,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { lparams); } -TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { +TEST_F(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -15938,7 +15962,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { auto lparams = reduction_params.value().lparams; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input0, input1}, lparams); auto cg_outputs = fe.runFusion({input0, input1}, lparams); auto aten_output2 = input0.sum({0}).add(input0); at::Tensor aten_output3 = at::empty({0}, options); @@ -15954,7 +15978,7 @@ TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { lparams); } -TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) { +TEST_F(NVFuserTest, FusionSegmentIoAlias_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -15966,7 +15990,7 @@ TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) { fusion->addInput(tv1); fusion->addInput(tv2); - TensorView* tv3 = add(tv0, new Double(1)); // Group 0 + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 TensorView* tv4 = max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, @@ -16008,7 +16032,7 @@ TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) { executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWelford1Output_CUDA) { +TEST_F(NVFuserTest, FusionWelford1Output_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16028,7 +16052,7 @@ TEST(NVFuserTest, FusionWelford1Output_CUDA) { testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { +TEST_F(NVFuserTest, FusionTranslate1Welford_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16037,7 +16061,8 @@ TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { fusion->addInput(tv0); auto tvs = Welford(tv0, {1}); - fusion->addOutput(tvs.var_sum); + auto tv_out = add(tv0, broadcast(tvs.avg, {false, true})); + fusion->addOutput(tv_out); FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto run_test = [&executor_cache, @@ -16047,9 +16072,13 @@ TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { auto outputs = executor_cache.runFusionWithInputs({t0}); // Square sums does not fit well in the testValidate assumptions, // so we just compare the divided output here. - outputs[0] /= inner_size; - auto t1 = t0.var({1}, false); - testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); + testValidate( + fusion, + outputs, + {t0}, + {t0.add(t0.mean({1}).unsqueeze(1))}, + __LINE__, + __FILE__); return executor_cache.getMostRecentKernelRuntime(); }; @@ -16057,21 +16086,25 @@ TEST(NVFuserTest, FusionTranslate1Welford_CUDA) { // Run a translated welford auto runtime1 = run_test(64); // Check it was translated - TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 2); TORCH_CHECK( - runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() == - ScheduleHeuristic::Persistent); + runtime1->fusionSegments()->groups().size() == 1 && + runtime1->fusionSegments()->groups()[0]->exprs().size() > 2); // Run an un-translated welford auto runtime2 = run_test(65536); - // Check it was not translated - TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 1); - TORCH_CHECK( - runtime2->schedulerHeuristics()->singleKernelHeuristics()->heuristc() == - ScheduleHeuristic::Reduction); + + bool found_welford = false; + for (auto group : runtime2->fusionSegments()->groups()) { + for (auto expr : group->exprs()) { + if (expr->isA()) { + found_welford = true; + } + } + } + TORCH_CHECK(found_welford); } -TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { +TEST_F(NVFuserTest, FusionTranslate2Welford_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16080,10 +16113,12 @@ TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { fusion->addInput(tv0); auto tvs1 = Welford(tv0, {1}); - auto tvs2 = Welford(tv0, {1}); + auto tv_out1 = add(tv0, broadcast(tvs1.avg, {false, true})); + fusion->addOutput(tv_out1); - fusion->addOutput(tvs1.var_sum); - fusion->addOutput(tvs2.var_sum); + auto tvs2 = Welford(tv0, {1}); + auto tv_out2 = add(tv0, broadcast(tvs2.avg, {false, true})); + fusion->addOutput(tv_out2); FusionExecutorCache executor_cache(std::move(fusion_ptr)); @@ -16095,10 +16130,8 @@ TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { // Square sums does not fit well in the testValidate assumptions, // so we just compare the divided output here. - outputs[0] /= inner_size; - outputs[1] /= inner_size; - auto t1 = t0.var({1}, false); - testValidate(fusion, outputs, {t0}, {t1, t1}, __LINE__, __FILE__); + auto out = t0.add(t0.mean({1}).unsqueeze(1)); + testValidate(fusion, outputs, {t0}, {out, out}, __LINE__, __FILE__); return executor_cache.getMostRecentKernelRuntime(); }; @@ -16106,18 +16139,25 @@ TEST(NVFuserTest, FusionTranslate2Welford_CUDA) { // Run a translated welford auto runtime1 = run_test(64); // Check it was translated - TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 4); TORCH_CHECK( - runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() == - ScheduleHeuristic::Persistent); + runtime1->fusionSegments()->groups().size() == 1 && + runtime1->fusionSegments()->groups()[0]->exprs().size() > 4); // Run an un-translated welford auto runtime2 = run_test(65536); // // Check it was not translated - TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 2); + bool found_welford = false; + for (auto group : runtime2->fusionSegments()->groups()) { + for (auto expr : group->exprs()) { + if (expr->isA()) { + found_welford = true; + } + } + } + TORCH_CHECK(found_welford); } -TEST(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { +TEST_F(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16150,7 +16190,7 @@ TEST(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { TORCH_CHECK(!runtime->isSegmented()); } -TEST(NVFuserTest, FusionWelfordOtherPersistence_CUDA) { +TEST_F(NVFuserTest, FusionWelfordOtherPersistence_CUDA) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16176,20 +16216,22 @@ TEST(NVFuserTest, FusionWelfordOtherPersistence_CUDA) { at::Tensor t0 = at::randn({128, inner_size}, options); auto outputs = executor_cache.runFusionWithInputs({t0}); - auto t1 = t0.mean({1}).unsqueeze(1) + t0; - auto t2 = t0.sum({1}).unsqueeze(1) + t0; + auto t1 = t0.to(c10::kDouble).mean({1}).unsqueeze(1) + t0; + auto t2 = t0.to(c10::kDouble).sum({1}).unsqueeze(1) + t0; testValidate(fusion, outputs, {t0}, {t2, t1}, __LINE__, __FILE__); return executor_cache.getMostRecentKernelRuntime(); }; for (auto inner_size : {4096, 8192, 32768}) { - auto runtime = run_test(4096); - TORCH_CHECK(!runtime->isSegmented()); + auto runtime = run_test(inner_size); + TORCH_CHECK( + !runtime->isSegmented() || + runtime->fusionSegments()->groups().size() == 1); } } -TEST(NVFuserTest, FusionSegmentIslands_CUDA) { +TEST_F(NVFuserTest, FusionSegmentIslands_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16211,7 +16253,7 @@ TEST(NVFuserTest, FusionSegmentIslands_CUDA) { fusion_executor_cache.runFusionWithInputs({t0, t1}); } -TEST(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16247,7 +16289,7 @@ TEST(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { TORCH_CHECK(tv8->getMaxProducerPosition() == 2); } -TEST(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16267,7 +16309,7 @@ TEST(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { TORCH_CHECK(tv3->getMaxProducerPosition() == 2); } -TEST(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16286,7 +16328,7 @@ TEST(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { TORCH_CHECK(tv3->getMaxProducerPosition() == 3); } -TEST(NVFuserTest, FusionSimpleWarp_CUDA) { +TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16317,14 +16359,14 @@ TEST(NVFuserTest, FusionSimpleWarp_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSimpleWarpPad_CUDA) { +TEST_F(NVFuserTest, FusionSimpleWarpPad_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16365,13 +16407,13 @@ TEST(NVFuserTest, FusionSimpleWarpPad_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { +TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16409,13 +16451,13 @@ TEST(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { auto at_output = input1.sum({1, 2}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSerialWarpReduction_CUDA) { +TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16450,13 +16492,13 @@ TEST(NVFuserTest, FusionSerialWarpReduction_CUDA) { auto at_output = input1.sum({1, 2}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTrivialWarpReduction_CUDA) { +TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16494,13 +16536,13 @@ TEST(NVFuserTest, FusionTrivialWarpReduction_CUDA) { auto at_output = input1.sum({1, 2, 3}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionMultipleDimBinding_CUDA) { +TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16548,7 +16590,7 @@ TEST(NVFuserTest, FusionMultipleDimBinding_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1, input2}); auto outputs = fe.runFusion({input1, input2}); testValidate( fusion.get(), @@ -16559,7 +16601,7 @@ TEST(NVFuserTest, FusionMultipleDimBinding_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionPadNoWarpReduce_CUDA) { +TEST_F(NVFuserTest, FusionPadNoWarpReduce_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16588,19 +16630,19 @@ TEST(NVFuserTest, FusionPadNoWarpReduce_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { +TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(2); fusion->addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); fusion->addOutput(tv2); @@ -16623,13 +16665,13 @@ TEST(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { auto at_output = (input1 + 1).sum({1}); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { +TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -16673,13 +16715,13 @@ TEST(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { auto at_output = input1.sum({1}, true).add(input1); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegfaultReduction_CUDA) { +TEST_F(NVFuserTest, FusionSegfaultReduction_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -16698,7 +16740,7 @@ TEST(NVFuserTest, FusionSegfaultReduction_CUDA) { std::vector at_sum_axes; std::vector outer_reduction_axes; std::vector outer_broadcast_mask(numDims, false); - Val* N = new Double(1); + Val* N = IrBuilder::create(1); for (const auto axis : c10::irange(numDims)) { if (axis != 1) { outer_reduction_axes.push_back(axis); @@ -16728,16 +16770,16 @@ TEST(NVFuserTest, FusionSegfaultReduction_CUDA) { &fusion, outputs, inputs, {at_output0, at_output1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPredicateElimination_CUDA) { +TEST_F(NVFuserTest, FusionPredicateElimination_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); - auto tv3 = add(tv2, new Double(3)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); + auto tv3 = add(tv2, IrBuilder::create(3)); fusion.addOutput(tv3); @@ -16748,7 +16790,7 @@ TEST(NVFuserTest, FusionPredicateElimination_CUDA) { { GpuLower gpulw(&fusion); - TORCH_CHECK(!isPredicated(tv2, gpulw)); + TORCH_CHECK(!PredicatedChecker::isPredicated(tv2, gpulw)); } tv2->axis(1)->parallelize(ParallelType::Serial); @@ -16756,11 +16798,11 @@ TEST(NVFuserTest, FusionPredicateElimination_CUDA) { { GpuLower gpulw(&fusion); - TORCH_CHECK(isPredicated(tv2, gpulw)); + TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); } } -TEST(NVFuserTest, FusionForceFp16Simple_CUDA) { +TEST_F(NVFuserTest, FusionForceFp16Simple_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16798,53 +16840,55 @@ TEST(NVFuserTest, FusionForceFp16Simple_CUDA) { } } -TEST(NVFuserTest, FusionForceBf16Simple_CUDA) { +TEST_F(NVFuserTest, FusionForceBf16Simple_CUDA) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } + + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); - fusion->addInput(tv0); - fusion->addInput(tv1); + // Group 1 + auto tv2 = sum(tv0, {1}); + auto tv3 = broadcast(tv2, {false, true}); - // Group 1 - auto tv2 = sum(tv0, {1}); - auto tv3 = broadcast(tv2, {false, true}); + // Group 2 + auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast + auto tv5 = castOp(DataType::BFloat16, tv4); - // Group 2 - auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast - auto tv5 = castOp(DataType::BFloat16, tv4); + fusion->addOutput(tv5); - fusion->addOutput(tv5); + FusionExecutorCache fec(std::move(fusion_ptr)); - FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector shape{15, 16}; - std::vector shape{15, 16}; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - // Check the segmented edge is bf16 - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - TORCH_CHECK(edge_tv->getDataType() == DataType::BFloat16); - } - } else { - GTEST_SKIP(); + // Check the segmented edge is bf16 + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + TORCH_CHECK(edge_tv->getDataType() == DataType::BFloat16); } #else - GTEST_SKIP(); + GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; #endif } -TEST(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { +TEST_F(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16893,64 +16937,66 @@ TEST(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { } } -TEST(NVFuserTest, FusionForceBf16NotAllCast_CUDA) { +TEST_F(NVFuserTest, FusionForceBf16NotAllCast_CUDA) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); - fusion->addInput(tv0); - fusion->addInput(tv1); + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); - // Group 1 - auto tv3 = sum(tv0, {1}); - auto tv4 = broadcast(tv3, {false, true, false}); - auto tv5 = sum(tv0, {1}); + fusion->addInput(tv0); + fusion->addInput(tv1); - // Group 2 - auto tv6 = add(tv4, tv1); // edge tv4, expect cast - auto tv7 = castOp(DataType::BFloat16, tv6); + // Group 1 + auto tv3 = sum(tv0, {1}); + auto tv4 = broadcast(tv3, {false, true, false}); + auto tv5 = sum(tv0, {1}); - // Group 3 - auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast + // Group 2 + auto tv6 = add(tv4, tv1); // edge tv4, expect cast + auto tv7 = castOp(DataType::BFloat16, tv6); - fusion->addOutput(tv7); - fusion->addOutput(tv8); + // Group 3 + auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast - FusionExecutorCache fec(std::move(fusion_ptr)); + fusion->addOutput(tv7); + fusion->addOutput(tv8); - std::vector shape{16, 16, 16}; + FusionExecutorCache fec(std::move(fusion_ptr)); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - auto complete_fusion = segmented_fusion->completeFusion(); - - // Check that the edge that wasn't fp16 is the producer of the - // reduction op, i.e. tv8 = sum(tv5,{1});. - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - if (edge_tv->getDataType() == DataType::Float) { - auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); - TORCH_CHECK(consumer->isA()); - } + std::vector shape{16, 16, 16}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); + + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + auto complete_fusion = segmented_fusion->completeFusion(); + + // Check that the edge that wasn't fp16 is the producer of the + // reduction op, i.e. tv8 = sum(tv5,{1});. + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + if (edge_tv->getDataType() == DataType::Float) { + auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); + TORCH_CHECK(consumer->isA()); } - } else { - GTEST_SKIP(); } #else - GTEST_SKIP(); + GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; #endif } -TEST(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16961,10 +17007,10 @@ TEST(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); + auto tv2 = mul(tv0, IrBuilder::create(2)); auto tv3 = broadcast(tv2, {false, false, true}); auto tv4 = add(tv3, tv1); - auto tv5 = mul(tv4, new Double(3)); + auto tv5 = mul(tv4, IrBuilder::create(3)); fusion->addOutput(tv5); // t4 cannot inner re-use t2, because there's a broadcast @@ -16978,13 +17024,13 @@ TEST(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto outputs = fe.runFusion({in0, in1}); testValidate(fusion, outputs, {in0, in1}, {at_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseStressTest_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -16995,17 +17041,17 @@ TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); - auto tv3 = mul(tv0, new Double(3)); + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = mul(tv0, IrBuilder::create(3)); auto tv4 = mul(tv2, tv3); // Broadcast buffer can be reused through outer sharing auto tv5 = broadcast(tv4, {true, false, false}); - auto tv6 = mul(tv5, new Double(5)); + auto tv6 = mul(tv5, IrBuilder::create(5)); auto tv7 = mul(tv6, tv1); - auto tv8 = mul(tv7, new Double(7)); + auto tv8 = mul(tv7, IrBuilder::create(7)); // tv9 shouldn't alias to avoid buffer over-subscription auto tv9 = broadcast(tv4, {true, false, false}); - auto tv10 = mul(tv9, new Double(9)); + auto tv10 = mul(tv9, IrBuilder::create(9)); auto tv11 = add(tv5, tv9); fusion->addOutput(tv7); fusion->addOutput(tv11); @@ -17031,7 +17077,7 @@ TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { auto t10 = t9 * 9; auto t11 = t5 + t9; FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; auto outputs = fe.runFusion({in0, in1}); @@ -17039,7 +17085,7 @@ TEST(NVFuserTest, FusionBufferReuseStressTest_CUDA) { testValidate(fusion, outputs, {in0, in1}, {t7, t11}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17048,12 +17094,12 @@ TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { fusion->addInput(tv0); - auto tv1 = mul(tv0, new Double(2)); - auto tv2 = mul(tv1, new Double(2)); - auto tv3 = mul(tv2, new Double(2)); - auto tv4 = mul(tv3, new Double(2)); - auto tv5 = mul(tv4, new Double(2)); - auto tv6 = mul(tv5, new Double(2)); + auto tv1 = mul(tv0, IrBuilder::create(2)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); + auto tv4 = mul(tv3, IrBuilder::create(2)); + auto tv5 = mul(tv4, IrBuilder::create(2)); + auto tv6 = mul(tv5, IrBuilder::create(2)); fusion->addOutput(tv6); @@ -17064,7 +17110,7 @@ TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { auto in0 = at::randn({256, 512}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0}); auto outputs = fe.runFusion({in0}); auto at_out = in0.mul(2).mul(2).mul(2).mul(2).mul(2).mul(2); @@ -17072,7 +17118,7 @@ TEST(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17083,12 +17129,12 @@ TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); + auto tv2 = mul(tv0, IrBuilder::create(2)); auto tv3 = broadcast(tv2, {false, false, true}); auto tv4 = add(tv3, tv1); // T4 to be inner aliased first, and // shouldn't outer alias on top - auto tv5 = mul(tv4, new Double(3)); - auto tv6 = mul(tv5, new Double(3)); + auto tv5 = mul(tv4, IrBuilder::create(3)); + auto tv6 = mul(tv5, IrBuilder::create(3)); fusion->addOutput(tv6); tv0->computeAt(tv6, 1, ComputeAtMode::BestEffort); @@ -17098,7 +17144,7 @@ TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { auto in0 = at::randn({2, 2}, options); auto in1 = at::randn({2, 2, 2}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto outputs = fe.runFusion({in0, in1}); auto at_out = (in0.mul(2.0).unsqueeze(2) + in1).mul(3.0).mul(3.0); @@ -17106,7 +17152,7 @@ TEST(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { testValidate(fusion, outputs, {in0, in1}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17116,8 +17162,8 @@ TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { fusion->addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = mul(tv1, new Double(2)); - auto tv3 = mul(tv2, new Double(2)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); fusion->addOutput(tv3); @@ -17134,7 +17180,7 @@ TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { auto in0 = at::randn({3, 3, 3}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0}); auto outputs = fe.runFusion({in0}); auto at_out = in0.sum(1).mul(2).mul(2); @@ -17142,7 +17188,7 @@ TEST(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17151,9 +17197,9 @@ TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { fusion->addInput(tv0); - auto tv1 = mul(tv0, new Double(3)); - auto tv2 = mul(tv1, new Double(2)); - auto tv3 = mul(tv2, new Double(2)); + auto tv1 = mul(tv0, IrBuilder::create(3)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); // tv1 used till here, cannot be reused by tv2 or tv3 auto tv4 = mul(tv3, tv1); @@ -17165,7 +17211,7 @@ TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { auto in0 = at::randn({16, 16}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0}); auto cg_outputs = fe.runFusion({in0}); auto at_t0 = in0 * 3.0; @@ -17174,7 +17220,7 @@ TEST(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { testValidate(fusion, cg_outputs, {in0}, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -17185,12 +17231,12 @@ TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = mul(tv0, new Double(2)); - auto tv3 = mul(tv0, new Double(3)); + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = mul(tv0, IrBuilder::create(3)); auto tv4 = mul(tv2, tv3); auto tv5 = broadcast(tv4, {false, false, true}); auto tv6 = mul(tv5, tv1); - auto tv7 = mul(tv6, new Double(7)); + auto tv7 = mul(tv6, IrBuilder::create(7)); fusion->addOutput(tv7); // tv6 shouldn't re-use t2 or t3 because of @@ -17202,7 +17248,7 @@ TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { auto in0 = at::randn({2, 2}, options); auto in1 = at::randn({2, 2, 2}, options); FusionExecutor fe; - fe.compileFusion(fusion); + fe.compileFusion(fusion, {in0, in1}); auto outputs = fe.runFusion({in0, in1}); auto t2 = in0 * 2; @@ -17214,7 +17260,7 @@ TEST(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { testValidate(fusion, outputs, {in0, in1}, {t7}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue970_CUDA) { +TEST_F(NVFuserTest, FusionIssue970_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17230,14 +17276,13 @@ TEST(NVFuserTest, FusionIssue970_CUDA) { tv1->split(1, 4); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); at::manual_seed(0); at::Tensor t0 = at::randn({nelm, nelm}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); auto ref = sum(t0, {1}).unsqueeze(-1).expand({nelm, nelm}) + t0; @@ -17246,15 +17291,15 @@ TEST(NVFuserTest, FusionIssue970_CUDA) { } // Reproducer of #1016 -TEST(NVFuserTest, FusionIssue1016_CUDA) { +TEST_F(NVFuserTest, FusionIssue1016_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); fusion.addOutput(tv2); @@ -17262,15 +17307,15 @@ TEST(NVFuserTest, FusionIssue1016_CUDA) { tv2->split(-1, 8); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 10; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0 + 1 + 2; @@ -17279,13 +17324,13 @@ TEST(NVFuserTest, FusionIssue1016_CUDA) { } // Reproducer of #1021 -TEST(NVFuserTest, FusionIssue1021_CUDA) { +TEST_F(NVFuserTest, FusionIssue1021_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = broadcast(tv1, {false, true}); fusion.addOutput(tv2); @@ -17298,12 +17343,12 @@ TEST(NVFuserTest, FusionIssue1021_CUDA) { tv2->axis(0)->parallelize(ParallelType::TIDx); tv2->axis(1)->parallelize(ParallelType::Vectorize); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = (t0 + 1).unsqueeze(-1); @@ -17312,7 +17357,7 @@ TEST(NVFuserTest, FusionIssue1021_CUDA) { } // Reproducer of issue #1053 -TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { +TEST_F(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -17321,7 +17366,7 @@ TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { auto tv1 = sum(tv0, {0}); fusion->addOutput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv2); tv1->split(0, 8); @@ -17340,20 +17385,20 @@ TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { auto at_tv2 = input1 + 1; FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( fusion.get(), outputs, {input1}, {at_tv1, at_tv2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap1_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(1); fusion->addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv1); fusion->addOutput(tv2); @@ -17366,25 +17411,22 @@ TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) { // actual values are not statically known GpuLower gpulw(fusion.get()); const auto& pdmap = gpulw.parallelDimensionMap(); - auto kir_tv1 = gpulw.lowerValue(tv1)->as(); - auto kir_tv2 = gpulw.lowerValue(tv2)->as(); - for (const auto i : c10::irange(kir_tv1->domain()->domain().size())) { - auto dom1 = kir_tv1->domain()->domain()[i]; - auto dom2 = kir_tv2->domain()->domain()[i]; + for (const auto i : c10::irange(tv1->domain()->domain().size())) { + auto dom1 = tv1->domain()->domain()[i]; + auto dom2 = tv2->domain()->domain()[i]; TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent())); } TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({32}, options); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( @@ -17396,7 +17438,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap2_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -17418,16 +17460,15 @@ TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) { const auto& pdmap = gpulw.parallelDimensionMap(); TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({11}, options); at::Tensor input2 = at::randn({11, 13}, options); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1, input2}); auto outputs = fe.runFusion({input1, input2}); auto ref = input1.unsqueeze(-1) + input2; @@ -17437,24 +17478,24 @@ TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) { } // Mix symbolic and concrete tensors -TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap3_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = makeSymbolicTensor(1); fusion->addInput(tv0); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv2); - auto tv3 = add(tv0, new Double(1)); + auto tv3 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv3); tv2->split(0, 10); tv3->split(0, 20); - auto tv4 = add(tv0, new Double(1)); + auto tv4 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv4); - auto tv5 = add(tv0, new Double(1)); + auto tv5 = add(tv0, IrBuilder::create(1)); fusion->addOutput(tv5); // Not mapped but equal extent @@ -17471,19 +17512,18 @@ TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) { const auto& pdmap = gpulw.parallelDimensionMap(); TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); TORCH_CHECK( pdmap.get(ParallelType::TIDy)->isConst() && - pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); + pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({13}, options); FusionExecutor fe; - fe.compileFusion(fusion.get()); + fe.compileFusion(fusion.get(), {input1}); auto outputs = fe.runFusion({input1}); testValidate( @@ -17496,7 +17536,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) { } // Parallelizing merged broadcast domains -TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17504,7 +17544,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { fusion.addInput(tv0); auto tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); @@ -17526,16 +17566,15 @@ TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { const auto& pdmap = gpulw.parallelDimensionMap(); TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == - "blockDim.x"); + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({13}, options); at::Tensor input2 = at::randn({15, 13}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); auto ref = (input1 + 1).unsqueeze(0) + input2; @@ -17543,7 +17582,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) { testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) { +TEST_F(NVFuserTest, FusionParallelDimensionMap5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17570,18 +17609,17 @@ TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) { TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); TORCH_CHECK( pdmap.get(ParallelType::TIDx)->isConst() && - pdmap.get(ParallelType::TIDx)->as()->value().value() == 4); + pdmap.get(ParallelType::TIDx)->as()->value().value() == 4); TORCH_CHECK( - pdmap.get(ParallelType::TIDy)->isA() && - pdmap.get(ParallelType::TIDy)->as()->name() == - "blockDim.y"); + pdmap.get(ParallelType::TIDy)->isA() && + pdmap.get(ParallelType::TIDy)->as()->name() == "blockDim.y"); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({13}, options); at::Tensor input2 = at::randn({13, 15}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {input1, input2}); auto outputs = fe.runFusion({input1, input2}); auto ref = (input1).unsqueeze(-1) + input2; @@ -17589,7 +17627,7 @@ TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) { testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { +TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -17603,7 +17641,7 @@ TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { auto t13 = makeSymbolicTensor(3, DataType::Half); auto t15 = makeSymbolicTensor(3, DataType::Half); auto t17 = makeSymbolicTensor(3, DataType::Half); - auto d56 = new Double(); + auto d56 = IrBuilder::create(); fusion.addInput(t0); fusion.addInput(t1); @@ -17636,9 +17674,10 @@ TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { auto t29 = mul(t25, t23); auto t30 = sum(t29, {2}); auto t31 = broadcast(t30, {false, false, true}); - auto d59 = mul(t1->getRootDomain()[2]->extent(), new Double(1)); + auto d59 = + mul(t1->getRootDomain()[2]->extent(), IrBuilder::create(1)); auto t26 = mul(d59, t25); - auto txx = mul(t26, new Double(1)); + auto txx = mul(t26, IrBuilder::create(1)); auto t33 = sub(txx, t28); auto d70 = unaryOp(UnaryOpType::Reciprocal, d59); auto t35 = mul(d70, t6); @@ -17694,23 +17733,23 @@ TEST(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { } } -TEST(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { +TEST_F(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); - auto tv3 = add(tv0, new Double(1)); - auto tv4 = add(tv3, new Double(1)); + auto tv3 = add(tv0, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); fusion.addOutput(tv4); - auto tv5 = add(tv0, new Double(1)); - auto tv6 = add(tv5, new Double(1)); + auto tv5 = add(tv0, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); fusion.addOutput(tv6); // Case 1: local memory tensor computed serially and used by @@ -17732,13 +17771,13 @@ TEST(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { tv5->axis(-1)->parallelize(ParallelType::TIDx); tv5->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int nx = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({nx}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = t0 + 2; @@ -17748,16 +17787,16 @@ TEST(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { } // Repro of issue #1105 -TEST(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { +TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -17783,12 +17822,12 @@ TEST(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { } } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 3; @@ -17796,24 +17835,24 @@ TEST(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1099_CUDA) { +TEST_F(NVFuserTest, FusionIssue1099_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); auto tv3 = makeSymbolicTensor(1); fusion.addInput(tv3); // Just to make TIDx/y/z non-exact - auto tv4 = add(tv3, new Double(1)); - auto tv5 = add(tv4, new Double(1)); - auto tv6 = add(tv5, new Double(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); fusion.addOutput(tv6); tv2->split(0, 4); @@ -17835,13 +17874,13 @@ TEST(NVFuserTest, FusionIssue1099_CUDA) { tv6->split(0, 7); tv6->axis(-1)->parallelize(ParallelType::TIDz); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t3 = at::randn({19}, options); std::vector aten_inputs = {t0, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref_t2 = t0 + 2; @@ -17852,18 +17891,15 @@ TEST(NVFuserTest, FusionIssue1099_CUDA) { } // Repro of issue #1080 -TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionUnswitchPredicate_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv2->split(0, 4); @@ -17883,14 +17919,14 @@ TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int nx = 4; const int ny = 10; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({nx, ny}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = t0 + 2; @@ -17898,7 +17934,7 @@ TEST(NVFuserTest, FusionUnswitchPredicate_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1189_CUDA) { +TEST_F(NVFuserTest, FusionIssue1189_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17926,12 +17962,12 @@ TEST(NVFuserTest, FusionIssue1189_CUDA) { parallelize(tv2); parallelize(tv3); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({16, 16, 1}, options); at::Tensor t1 = at::randn({16, 16, 1}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); auto outputs = fe.runFusion({t0, t1}); auto ref = (t0 + t1).sum({1}); @@ -17939,7 +17975,7 @@ TEST(NVFuserTest, FusionIssue1189_CUDA) { testValidate(&fusion, outputs, {t0, t1}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1052_CUDA) { +TEST_F(NVFuserTest, FusionIssue1052_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -17948,10 +17984,10 @@ TEST(NVFuserTest, FusionIssue1052_CUDA) { auto tv1 = makeSymbolicTensor(1); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv2); - auto tv3 = add(tv1, new Double(1)); + auto tv3 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv3); tv2->axis(-1)->parallelize(ParallelType::TIDx); @@ -17960,13 +17996,13 @@ TEST(NVFuserTest, FusionIssue1052_CUDA) { scheduler_utils::parallelizeAllLike(tv2, {tv0}); scheduler_utils::parallelizeAllLike(tv3, {tv1}); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10}, options); at::Tensor t1 = at::randn({100}, options); std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref_t2 = t0 + 1; @@ -17977,7 +18013,7 @@ TEST(NVFuserTest, FusionIssue1052_CUDA) { } // Repro of issue #1115 -TEST(NVFuserTest, FusionPointwiseBroadcast_CUDA) { +TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18002,7 +18038,7 @@ TEST(NVFuserTest, FusionPointwiseBroadcast_CUDA) { schedulePointwise(&fusion, aten_inputs); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto at_x_add_bias = at_x + at_bias; @@ -18012,23 +18048,23 @@ TEST(NVFuserTest, FusionPointwiseBroadcast_CUDA) { testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionSmemAliasSerial_CUDA) { +TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); // Just set the dimension of TIDx auto tv4 = makeSymbolicTensor(1); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); tv1->setMemoryType(MemoryType::Shared); @@ -18040,14 +18076,13 @@ TEST(NVFuserTest, FusionSmemAliasSerial_CUDA) { // TIDx. They should be predicated as they are redundant and can // interfere with smem aliasing (issue #1100). - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10}, options); - at::Tensor t4 = at::randn({1024}, options); std::vector aten_inputs = {t0, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 3; @@ -18056,14 +18091,14 @@ TEST(NVFuserTest, FusionSmemAliasSerial_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { +TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); auto tv2 = makeSymbolicTensor(1); @@ -18074,13 +18109,13 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { tv1->axis(0)->parallelize(ParallelType::TIDx); tv3->axis(0)->parallelize(ParallelType::BIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t2 = at::randn({19}, options); std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 1; @@ -18089,14 +18124,14 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); auto tv2 = makeSymbolicTensor(1); @@ -18107,13 +18142,13 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { tv1->axis(0)->parallelize(ParallelType::TIDx); tv3->axis(0)->parallelize(ParallelType::BIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t2 = at::randn({19}, options); std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 1; @@ -18122,7 +18157,7 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { +TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18134,12 +18169,12 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { auto tv2 = makeSymbolicTensor(3); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); auto tv4 = makeSymbolicTensor(3); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); tv1->axis(0)->parallelize(ParallelType::BIDx); @@ -18175,7 +18210,7 @@ TEST(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { #endif } -TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { +TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18187,12 +18222,12 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { auto tv2 = makeSymbolicTensor(3); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); auto tv4 = makeSymbolicTensor(3); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); tvs.avg->axis(0)->parallelize(ParallelType::BIDx); @@ -18229,7 +18264,7 @@ TEST(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { } // Repro of issue #1102 -TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { +TEST_F(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18237,18 +18272,18 @@ TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { fusion.addInput(tv0); // Just to make TIDx/y/z non-exact - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); auto tv4 = makeSymbolicTensor(1); fusion.addInput(tv4); - auto tv5 = add(tv4, new Double(1)); - auto tv6 = add(tv5, new Double(1)); - auto tv7 = add(tv6, new Double(1)); - auto tv8 = add(tv7, new Double(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); + auto tv7 = add(tv6, IrBuilder::create(1)); + auto tv8 = add(tv7, IrBuilder::create(1)); auto tv9 = sum(tv8, {0}); fusion.addOutput(tv9); @@ -18274,13 +18309,13 @@ TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { tv5->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t4 = at::randn({19}, options); std::vector aten_inputs = {t0, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 3; @@ -18290,8 +18325,9 @@ TEST(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { } // Repro of #1102 and #1129 -TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } Fusion fusion; @@ -18302,16 +18338,16 @@ TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { auto tv1 = makeSymbolicTensor(1); fusion.addInput(tv1); - auto tv2 = add(tv0, new Double(1)); - auto tv3 = add(tv2, new Double(1)); - auto tv4 = add(tv3, new Double(1)); - auto tv5 = add(tv4, new Double(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); fusion.addOutput(tv5); // Just to make TIDx/y/z non-exact - auto tvx = add(tv1, new Double(1)); - auto tvy = add(tvx, new Double(1)); - auto tvz = add(tvy, new Double(1)); + auto tvx = add(tv1, IrBuilder::create(1)); + auto tvy = add(tvx, IrBuilder::create(1)); + auto tvz = add(tvy, IrBuilder::create(1)); fusion.addOutput(tvz); tv5->split(0, 4); @@ -18335,13 +18371,13 @@ TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { tv->setMemoryType(MemoryType::Shared); } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({17}, options); at::Tensor t1 = at::randn({19}, options); std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref1 = t0 + 4; @@ -18351,21 +18387,24 @@ TEST(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { } // Repro of issue #1136 -TEST(NVFuserTest, FusionFloatPow_CUDA) { +TEST_F(NVFuserTest, FusionFloatPow_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = binaryOp(BinaryOpType::Pow, tv0, new Int(4)); + auto tv1 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(4)); // To check if pow(tv0, 2) is replaced with tv0 * tv0 - auto tv2 = binaryOp(BinaryOpType::Pow, tv0, new Int(2)); + auto tv2 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); // To check if pow(tv0, 2.0) is replaced with tv0 * tv0 - auto tv3 = binaryOp(BinaryOpType::Pow, tv0, new Double(2)); - auto tv4 = binaryOp(BinaryOpType::Pow, tv0, new Int(3)); - auto tv5 = binaryOp(BinaryOpType::Pow, tv0, new Double(3)); - auto s = binaryOp(BinaryOpType::Pow, new Double(3), new Double(3)); + auto tv3 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); + auto tv4 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); + auto tv5 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); + auto s = binaryOp( + BinaryOpType::Pow, + IrBuilder::create(3), + IrBuilder::create(3)); auto tv6 = add(tv0, s); fusion.addOutput(tv1); @@ -18382,14 +18421,14 @@ TEST(NVFuserTest, FusionFloatPow_CUDA) { TransformPropagator::from(tv1); scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); // Negative inputs cause nan in Fuesr as use_fast_math is enabled t0 = abs(t0); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto p4 = at::pow(t0, 4); @@ -18406,7 +18445,7 @@ TEST(NVFuserTest, FusionFloatPow_CUDA) { __FILE__); } -TEST(NVFuserTest, FusionIssue1127_CUDA) { +TEST_F(NVFuserTest, FusionIssue1127_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18435,7 +18474,7 @@ TEST(NVFuserTest, FusionIssue1127_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -TEST(NVFuserTest, FusionChannelsLastParser_CUDA) { +TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // This test may not pass if using a custom block sync as there may // be additional calls. Skip the test as it's not specifically // relevant with block synchronizatin. @@ -18486,30 +18525,30 @@ TEST(NVFuserTest, FusionChannelsLastParser_CUDA) { const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { if ((((((((((nvfuser_index_t)blockIdx.x) * 1) + 0) * 1) + 0) * 128) + ((nvfuser_index_t)threadIdx.x)) < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { - constexpr nvfuser_index_t ki674 = 0; + constexpr nvfuser_index_t i120 = 0; __half T9[1]; - constexpr nvfuser_index_t ki716 = 0; - T9[ki716] = 0; - constexpr nvfuser_index_t ki707 = 0; - T9[ki707] - = T2[((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * (((1 * T0.size[2]) * T0.size[1]) * T0.size[3])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * ((1 * T0.size[2]) * T0.size[1])) + (((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * (1 * T0.size[2])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki707) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3]) * 1)]; + constexpr nvfuser_index_t i132 = 0; + T9[i132] = 0; + constexpr nvfuser_index_t i128 = 0; + T9[i128] + = T2[((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * (((1 * T0.size[2]) * T0.size[1]) * T0.size[3])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * ((1 * T0.size[2]) * T0.size[1])) + (((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * (1 * T0.size[2])) + ((((((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i128) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3]) * 1)]; __half T8[1]; - constexpr nvfuser_index_t ki722 = 0; - T8[ki722] = 0; - constexpr nvfuser_index_t ki702 = 0; - T8[ki702] - = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki702) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; + constexpr nvfuser_index_t i134 = 0; + T8[i134] = 0; + constexpr nvfuser_index_t i126 = 0; + T8[i126] + = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i126) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]; __half T10[1]; - constexpr nvfuser_index_t ki683 = 0; + constexpr nvfuser_index_t i124 = 0; float T3[1]; T3[0] - = __half2float(T9[ki683]); + = __half2float(T9[i124]); float T4[1]; T4[0] = T3[0]; float T1[1]; T1[0] - = __half2float(T8[ki683]); + = __half2float(T8[i124]); float T5[1]; T5[0] = T1[0] @@ -18517,11 +18556,11 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, float T6[1]; T6[0] = relu(T5[0]); - T10[ki683] + T10[i124] = __float2half(T6[0]); - constexpr nvfuser_index_t ki676 = 0; - T7[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki674) * 1) + ki676) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] - = T10[ki676]; + constexpr nvfuser_index_t i122 = 0; + T7[(((((((((nvfuser_index_t)blockIdx.x) * 1) + i120) * 1) + i122) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)] + = T10[i122]; } } )"; @@ -18558,7 +18597,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, // TORCH_CHECK(output_ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { +TEST_F(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18566,8 +18605,8 @@ TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, new Double(1)); - auto tv3 = add(tv2, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -18575,12 +18614,12 @@ TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { tv2->computeAt(tv3, -1); tv3->axis(0)->parallelize(ParallelType::Unswitch); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10, 1024}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = sum(t0, {1}) + 2; @@ -18588,24 +18627,24 @@ TEST(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonContigOutputs_CUDA) { +TEST_F(NVFuserTest, FusionNonContigOutputs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); fusion.addOutput(tv1); tv1->setContiguity(false); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor at_input = at::randn({10}, options); at::Tensor at_output = at::empty_strided({10}, {2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_input}); auto returned_outputs = fe.runFusion({at_input}, {at_output}); // Returned outputs should only contain one tensor that is the same @@ -18619,7 +18658,7 @@ TEST(NVFuserTest, FusionNonContigOutputs_CUDA) { testValidate(&fusion, {at_output}, {at_input}, {at_ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionTestWarpSoftMax_CUDA) { +TEST_F(NVFuserTest, FusionTestWarpSoftMax_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18654,14 +18693,15 @@ TEST(NVFuserTest, FusionTestWarpSoftMax_CUDA) { // Test result FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref_output = at::_softmax(aten_input, 1, false); testValidate(&fusion, outputs, aten_inputs, {ref_output}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1133_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionIssue1133_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } Fusion fusion; @@ -18670,9 +18710,9 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -18702,20 +18742,20 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { // There should be no allocation other than those for tv1 and tv2 TORCH_CHECK(false, "Invalid allocation detected"); } - TORCH_CHECK(size->isA(), "Invalid allocation size"); - TORCH_CHECK(size->as()->isConst(), "Allocation not constant"); - auto size_int = size->as()->value().value(); + TORCH_CHECK(size->isA(), "Invalid allocation size"); + TORCH_CHECK(size->as()->isConst(), "Allocation not constant"); + auto size_int = size->as()->value().value(); if (alloc->buffer()->name() == 1) { TORCH_CHECK( size_int == split_factor, "Invalid allocation size: ", - size->as()->value().value()); + size->as()->value().value()); tv1_validated = true; } else { TORCH_CHECK( size_int == 1, "Invalid allocation size: ", - size->as()->value().value()); + size->as()->value().value()); tv2_validated = true; } } @@ -18724,12 +18764,12 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { TORCH_CHECK(tv1_validated, "Failed to validate tv1 allocation"); TORCH_CHECK(tv2_validated, "Failed to validate tv2 allocation"); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({99, 101}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = (t0 + 1).sum({1}) + 1; @@ -18737,7 +18777,7 @@ TEST(NVFuserTest, FusionIssue1133_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionRfactorContigIDs_CUDA) { +TEST_F(NVFuserTest, FusionRfactorContigIDs_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18756,12 +18796,12 @@ TEST(NVFuserTest, FusionRfactorContigIDs_CUDA) { tv2->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({99, 101}, options); std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); auto outputs = fe.runFusion(aten_inputs); auto ref = t0.sum({1}); @@ -18769,7 +18809,7 @@ TEST(NVFuserTest, FusionRfactorContigIDs_CUDA) { testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18831,7 +18871,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { aten_t0.size(1) * dataTypeSize(DataType::Float)); } -TEST(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18894,7 +18934,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { aten_t0.size(1) * dataTypeSize(DataType::Half)); } -TEST(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -18979,7 +19019,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { (dataTypeSize(DataType::Half) + dataTypeSize(DataType::Float))); } -TEST(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { +TEST_F(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19056,10 +19096,7 @@ TEST(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { aten_t0.size(1) * dataTypeSize(DataType::Half)); } -TEST(NVFuserTest, PersistentBufferProjection_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionPersistentBufferProjection_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -19107,8 +19144,9 @@ TEST(NVFuserTest, PersistentBufferProjection_CUDA) { testValidate(&fusion, cg_outputs, {aten_t0}, {aten_t7}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionIssue1223_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 7) { +TEST_F(NVFuserTest, FusionIssue1223_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; return; } Fusion fusion; @@ -19117,11 +19155,11 @@ TEST(NVFuserTest, FusionIssue1223_CUDA) { auto tv0 = makeContigTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0, 1}); fusion.addOutput(tv2); - auto tv3 = add(tv0, new Double(0)); + auto tv3 = add(tv0, IrBuilder::create(0)); fusion.addOutput(tv3); tv2->split(0, 4); @@ -19153,7 +19191,7 @@ TEST(NVFuserTest, FusionIssue1223_CUDA) { at::Tensor at_t0 = at::ones({11, 10}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {at_t0}); auto cg_outputs = fe.runFusion({at_t0}); auto at_t1 = (at_t0 + 1).sum(); @@ -19163,14 +19201,14 @@ TEST(NVFuserTest, FusionIssue1223_CUDA) { } // See #1247 and #1250 -TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { +TEST_F(NVFuserTest, FusionRfactorPredication1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = min(tv1, {0}); fusion.addOutput(tv2); @@ -19179,7 +19217,7 @@ TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { auto tv3 = makeContigTensor(1); fusion.addInput(tv3); - auto tv4 = add(tv3, new Double(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); fusion.addOutput(tv4); tv2->split(0, 4); @@ -19197,7 +19235,7 @@ TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { at::Tensor at_t3 = at::randn({128}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {at_t0, at_t3}); auto cg_outputs = fe.runFusion({at_t0, at_t3}); auto at_t2 = (at_t0 + 1).min(); @@ -19207,7 +19245,7 @@ TEST(NVFuserTest, FusionRfactorPredication1_CUDA) { &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { +TEST_F(NVFuserTest, FusionRfactorPredication2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19221,7 +19259,7 @@ TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { auto tv2 = makeContigTensor(1); fusion.addInput(tv2); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); tv1->split(0, 4); @@ -19250,7 +19288,7 @@ TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { at::Tensor at_t3 = at::randn({128}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {at_t0, at_t3}); auto cg_outputs = fe.runFusion({at_t0, at_t3}); auto at_t2 = std::get<0>(at_t0.min(0)); @@ -19260,7 +19298,7 @@ TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19292,7 +19330,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, "Only tv1 should have a non-divisible predicate."); - for (auto tv : {tv1}) { + for (auto tv : {loweredTv(tv1, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19309,7 +19347,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = t0.sum(); @@ -19318,14 +19356,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { } // Repro of issue #1074 -TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); tv2->split(0, 2); @@ -19346,7 +19384,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, "Only tv2 should have a non-divisible predicate."); - for (auto tv : {tv2}) { + for (auto tv : {loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19363,7 +19401,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { at::Tensor t0 = at::randn({13, 17}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = t0 + 2; @@ -19372,14 +19410,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { } // Similar to FusionNonDivisibleSplit1 but with unswitch -TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0}); fusion.addOutput(tv2); @@ -19397,7 +19435,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {tv1, tv2}) { + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19414,7 +19452,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19423,14 +19461,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { } // Non-divisible split through merge -TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0, 1}); fusion.addOutput(tv2); @@ -19447,7 +19485,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {tv1, tv2}) { + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19464,7 +19502,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { at::Tensor t0 = at::randn({24, 2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19473,14 +19511,14 @@ TEST(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { } // Nested splits -TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {0}); fusion.addOutput(tv2); @@ -19501,7 +19539,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { TORCH_CHECK( gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {tv1, tv2}) { + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); TORCH_CHECK( it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), @@ -19518,7 +19556,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19527,7 +19565,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { } // Vectorized non-divisible split. Must be validated at run time -TEST(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19556,13 +19594,12 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { splits_to_predicate); } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); - auto t0 = at::randn({32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = t0; @@ -19576,7 +19613,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { } // If a split is validated at run time, it's not necessary to predicate. -TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -19584,7 +19621,7 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { fusion.addInput(tv0); auto tv1 = set(tv0); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = sum(tv2, {0}); fusion.addOutput(tv3); @@ -19611,13 +19648,13 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { splits_to_predicate); } - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); auto t0 = at::randn({1024}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = (t0 + 1).sum(); @@ -19625,6 +19662,784 @@ TEST(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionIssue1284Repro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector input_shape_0 = {10, 20}; + std::vector input_shape_1 = {15}; + + TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); + TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); + fusion.addInput(in_0); + fusion.addInput(in_1); + + TensorView* out_0 = add(in_0, IrBuilder::create(0.f)); + TensorView* out_1 = add(in_1, IrBuilder::create(2.f)); + + fusion.addOutput(out_0); + fusion.addOutput(out_1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_in_0 = at::randn(input_shape_0, options); + at::Tensor at_in_1 = at::randn(input_shape_1, options); + std::vector aten_inputs = {at_in_0, at_in_1}; + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + auto t1 = at_in_1 + 2; + + auto runtime = fec.getMostRecentKernelRuntime(); + TORCH_INTERNAL_ASSERT(runtime->isSegmented()); + TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); + + testValidate( + &fusion, outputs, {at_in_0, at_in_1}, {at_in_0, t1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1284Repro2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector input_shape_0 = {4, 4}; + std::vector input_shape_1 = {3, 4, 4}; + std::vector input_shape_2 = {2, 8, 4, 4}; + + TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); + TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); + TensorView* in_2 = makeSymbolicTensor(input_shape_2.size()); + + fusion.addInput(in_0); + fusion.addInput(in_1); + fusion.addInput(in_2); + + TensorView* out_0 = add(in_0, in_1); + TensorView* out_1 = add(in_0, in_2); + + fusion.addOutput(out_0); + fusion.addOutput(out_1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_in_0 = at::randn(input_shape_0, options); + at::Tensor at_in_1 = at::randn(input_shape_1, options); + at::Tensor at_in_2 = at::randn(input_shape_2, options); + + std::vector aten_inputs = {at_in_0, at_in_1, at_in_2}; + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + auto t0 = at_in_0 + at_in_1; + auto t1 = at_in_0 + at_in_2; + + auto runtime = fec.getMostRecentKernelRuntime(); + TORCH_INTERNAL_ASSERT(runtime->isSegmented()); + TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); + + testValidate( + &fusion, + outputs, + {at_in_0, at_in_1, at_in_2}, + {t0, t1}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1305Repro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto t0 = makeContigTensor(1); + auto t1 = makeContigTensor(2); + + fusion.addInput(t0); + fusion.addInput(t1); + + auto t2 = broadcast(t0, {true, false}); + auto t3 = add(t1, t2); + auto t4 = add(t3, t2); + auto t5 = sum(t4, {1}); + auto t6 = broadcast(t5, {false, true}); + auto t7 = add(t3, t6); + + fusion.addOutput(t7); + + t3->computeAt(t7, -1, ComputeAtMode::MostInlined); + + TORCH_INTERNAL_ASSERT(t3->getComputeAtPosition() == 1); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 1); + + tv3->axis(-2)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, -1); + + tv3->axis(-2)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 1); + + // tv2 is invalid to double-buffer as its producer, tv1, is + // computed inside the double-buffering loop. + ASSERT_ANY_THROW(tv2->doubleBuffer()); + + // Moving tv2 inner makes tv1 large enough to double-buffer tv2 + tv2->computeAt(tv3, 2); + + tv2->doubleBuffer(); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering smem to local and unswitch +TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + tv3->split(-1, 8); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 2); + tv2->computeAt(tv3, -1); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::Unswitch); + scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + + tv2->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering gmem to shared and unswitch +TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + + tv2->split(-1, 128); + tv2->split(-1, 32); + tv2->split(-1, 8); + TransformPropagator::from(tv2); + + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, -1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Unswitch); + scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering smem to local and unroll +TEST_F(NVFuserTest, FusionDoubleBuffering6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 16); + tv3->split(-2, 4); + tv3->split(-2, 2); + TransformPropagator::from(tv3); + + tv0->computeAt(tv3, 1); + tv2->computeAt(tv3, -1); + + tv3->axis(2)->parallelize(ParallelType::Unroll); + tv3->axis(4)->parallelize(ParallelType::TIDx); + + tv2->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({199}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering and vectorize +TEST_F(NVFuserTest, FusionDoubleBuffering7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv2->split(-1, 128); + tv2->split(-1, 4); + TransformPropagator::from(tv2); + + tv1->computeAt(tv2, 2); + + tv2->axis(-2)->parallelize(ParallelType::TIDx); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({200}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Multiple tensors to double-buffer +TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(1); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv4->split(0, 32); + tv4->split(0, 4); + TransformPropagator::from(tv4); + + tv0->computeAt(tv4, 1); + tv1->computeAt(tv4, 1); + + tv4->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion)); + + tv2->doubleBuffer(); + tv3->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({100}, options); + auto t1 = at::randn({100}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Nested double buffering from gmem to smem and smem to register +TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto out = tv1; + fusion.addOutput(out); + + auto tv2 = tv0->cache_after(); + auto tv3 = tv2->cache_after(); + + out->split(0, 32); + out->split(0, 4); + TransformPropagator::from(out); + + tv2->setMemoryType(MemoryType::Shared); + + tv2->computeAt(out, 1); + tv3->computeAt(out, -1); + + out->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + tv2->doubleBuffer(); + tv3->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1001}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// FusionSmemBlockGemmCache + double buffering at both smem and local +TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(2); // (M, K) + TensorView* tv1 = makeSymbolicTensor(2); // (K, N) + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + TensorView* tv6 = tv5->cache_before(); + + // For smem double buffering + auto tv0_cache_local = tv0->cache_after(); + auto tv1_cache_local = tv1->cache_after(); + + // For register double buffering + auto tv0_cache_smem = tv0->cache_after(); + auto tv1_cache_smem = tv1->cache_after(); + + const int BSX = 32; + const int TSX = 8; + + // [M, K, N] + tv6->split(-1, BSX); + tv6->split(-1, TSX); + tv6->split(1, BSX); + tv6->split(0, BSX); + tv6->split(1, TSX); + // [M/BSX, BSX/TSX, TSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->reorder( + {{4, 7}, {7, 6}, {6, 5}, {2, 4}, {1, 3}, {3, 2}, {5, 1}, {0, 0}}); + // [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + + auto tv6_rf = tv6->rFactor({-1}); + + TransformPropagator::from(tv6_rf); + + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + + tv6_rf->computeAt(tv6, -1); + tv0_cache_local->computeAt(tv6_rf, -1); + tv1_cache_local->computeAt(tv6_rf, -1); + + tv0_cache_smem->setMemoryType(MemoryType::Shared); + tv1_cache_smem->setMemoryType(MemoryType::Shared); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-3)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); + + tv0_cache_local->doubleBuffer(); + tv1_cache_local->doubleBuffer(); + + tv0_cache_smem->doubleBuffer(); + tv1_cache_smem->doubleBuffer(); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { + auto mem_types = {MemoryType::Shared, MemoryType::Local}; + + for (auto mem_type : mem_types) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(mem_type); + + tv3->split(-1, 4); + TransformPropagator::from(tv3); + + tv1->computeAt(tv3, -2); + + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({15}, options); + FusionExecutor fe; + fe.compileFusion(&fusion); + + // This should throw an exception as the extent of t0 is not + // divisible by the vector width + ASSERT_ANY_THROW(fe.runFusion({t0})); + + auto t1 = at::randn({16}, options); + auto cg_outputs = fe.runFusion({t1}); + + auto ref = t1; + + testValidate(&fusion, cg_outputs, {t1}, {ref}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({10, 1}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({10, 20}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({10, 10}); + fusion.addInput(tv2); + + // Not concretized + auto tv3 = sum(tv2, {1}); + auto tv4 = broadcast(tv3, {false, true}); + auto tv5 = add(tv0, tv4); + fusion.addOutput(tv5); + + // Concretized + auto tv6 = sum(tv2, {1}); + auto tv7 = broadcast(tv6, {false, true}); + auto tv8 = add(tv1, tv7); + fusion.addOutput(tv8); + + for (auto tv : {tv3, tv4, tv5, tv6, tv7, tv8}) { + tv->axis(1)->parallelize(ParallelType::TIDx); + } + + GpuLower gpulw(&fusion); + TORCH_CHECK(!gpulw.concretizedBroadcastDomains().isConcretized( + loweredTv(tv4, gpulw)->axis(1))); + TORCH_CHECK(gpulw.concretizedBroadcastDomains().isConcretized( + loweredTv(tv7, gpulw)->axis(1))); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({10, 1}, options); + auto t1 = at::randn({10, 20}, options); + auto t2 = at::randn({10, 10}, options); + std::vector aten_inputs = {t0, t1, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t5 = t0 + t2.sum({1}).unsqueeze(-1); + auto t8 = t1 + t2.sum({1}).unsqueeze(-1); + + testValidate(&fusion, outputs, aten_inputs, {t5, t8}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 1}); + auto tv2 = broadcast(tv1, {true}); + auto tv3 = broadcast(tv2, {false, true}); + fusion.addOutput(tv3); + + // tv1 is thread-predicated with TIDx and TIDy + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::TIDy); + // tv2 broadcasts along TIDx + tv2->axis(0)->parallelize(ParallelType::TIDx); + // tv3 broadcasts along TIDy + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDy); + + // Both tv2 and tv3 broadcast along predicated TID dimensions, but + // since the broadcast domains are not concretized, there should be + // no actual parallel broadcast + + GpuLower gpulw(&fusion); + TORCH_CHECK( + !gpulw.kernel()->summary().has_block_broadcasts && + !gpulw.kernel()->summary().has_grid_broadcasts, + "There must be no parallel broadcast in this fusion"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({10, 11}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t3 = t0.sum().unsqueeze(-1).unsqueeze(-1); + + testValidate(&fusion, outputs, aten_inputs, {t3}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector input_shape({10, 4, 8}); + std::vector output_shape({8, 4, 1}); + + auto tv0 = makeConcreteTensor(input_shape); + fusion.addInput(tv0); + + auto tv2 = sum(tv0, {0}); + auto tv3 = set(tv2); + auto tv4 = + view(tv3, {input_shape.begin() + 1, input_shape.end()}, output_shape); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + // The view op adds a broadcast domain in tv4, which is + // parallelized. Howver, it is never materialized, so there should + // be no parallel broadcast. + + GpuLower gpulw(&fusion); + TORCH_CHECK( + !gpulw.kernel()->summary().has_block_broadcasts && + !gpulw.kernel()->summary().has_grid_broadcasts, + "There must be no parallel broadcast in this fusion"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(input_shape, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t5 = at::native::view(t0.sum(0), output_shape) + 1; + + testValidate(&fusion, outputs, aten_inputs, {t5}, __LINE__, __FILE__); +} + +// Merging non-broadcast and broadcast domains +// TODO: Fix use case see issue https://github.com/csarofeen/pytorch/issues/1418 +// validateParallelize does not pass. Even if it's skipped, +// generated code is invalid as blockBroadcast is not used. +#if 0 +TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + fusion.addOutput(tv3); + + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv2->merge(0, 1); + tv2->axis(0)->parallelize(ParallelType::TIDx); + // TODO: When set to shared memory, this kernel should be correct, but fails + // validation and when skipped produces incorrect code + tv2->setMemoryType(MemoryType::Shared); + + tv3->merge(0, 1); + tv3->axis(0)->parallelize(ParallelType::TIDx); + + fusion.printMath(); + fusion.printKernel(); +} +#endif + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/test/cpp/jit/test_gpu_shift.cpp b/test/cpp/jit/test_gpu_shift.cpp index 71fa156c2d2..2665f16563b 100644 --- a/test/cpp/jit/test_gpu_shift.cpp +++ b/test/cpp/jit/test_gpu_shift.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -18,8 +19,6 @@ #include #include #include -#include -#include #include #include #include @@ -82,33 +81,38 @@ void checkIntValue( void checkIntValue( kir::ExpressionEvaluator& evaluator, - const kir::Val* val, - kir::Int::ScalarType expected_value) { + const Val* val, + Int::ScalarType expected_value) { const auto actual_value = evaluator.evaluate(val); TORCH_CHECK(actual_value.has_value()); TORCH_CHECK(actual_value.value() == expected_value); } +// Used to signify invalid ranges, i.e., values at offset 0 to +// start_offset, and values at offset stop_offset to the end of the +// domain. +static constexpr int invalid_marker = 1; + // ATen version of tensor shifting auto shift( at::Tensor tensor, const std::vector& offsets, - std::vector strides = {}) { + std::vector padding = {}) { TORCH_INTERNAL_ASSERT(tensor.ndimension() == offsets.size()); - if (strides.empty()) { - strides = std::vector(tensor.ndimension(), 1); + if (padding.empty()) { + padding = offsets; + for (auto& p : padding) { + p = std::abs(p); + } } at::Tensor t = tensor; - std::vector stride_indices; for (size_t i = 0; i < offsets.size(); ++i) { - auto stride = strides[i]; - stride_indices.push_back( - at::indexing::Slice(0, at::indexing::None, stride)); - const auto offset = offsets[i]; + auto offset = offsets[i]; + t = t.roll(offsets[i], i); if (offset == 0) { continue; } - t = t.roll(offsets[i], i); + // Zero padding std::vector indices( tensor.ndimension(), at::indexing::Slice(0, at::indexing::None)); if (offset > 0) { @@ -117,8 +121,20 @@ auto shift( indices[i] = at::indexing::Slice(offset, at::indexing::None); } t.index(indices) = 0; + // Fill the outside range by the special marker value. + const auto pad = padding[i]; + if (offset > 0) { + indices[i] = at::indexing::Slice(0, offset - pad); + } else { + offset += pad; + TORCH_INTERNAL_ASSERT(offset <= 0); + if (offset == 0) { + continue; + } + indices[i] = at::indexing::Slice(offset, at::indexing::None); + } + t.index(indices) = invalid_marker; } - t = t.index(stride_indices); return t; } @@ -153,13 +169,28 @@ auto gather( TORCH_CHECK(w_size != 0); const auto& pad = pad_width[i]; TORCH_CHECK(pad.size() == 2); + const auto out_extent_adj = -w_size + 1 + pad[0] + pad[1]; + TORCH_INTERNAL_ASSERT(out_extent_adj <= 0); + const auto stride = strides[i]; + TORCH_CHECK(stride >= 1); + at::Tensor concat_tensor; + for (int w = 0; w < w_size; ++w) { std::vector shift_offsets(t.ndimension(), 0); shift_offsets[i] = pad[0] - w; - std::vector shift_strides(t.ndimension(), 1); - shift_strides[i] = strides[i]; - auto shifted = shift(t, shift_offsets, shift_strides); + auto shifted = shift(t, shift_offsets); + // Apply stride + if (stride != 1) { + std::vector indices( + shifted.ndimension(), at::indexing::Slice(0, at::indexing::None)); + if (out_extent_adj == 0) { + indices[i] = at::indexing::Slice(0, at::indexing::None, strides[i]); + } else { + indices[i] = at::indexing::Slice(0, out_extent_adj, strides[i]); + } + shifted = shifted.index(indices); + } shifted = shifted.unsqueeze(-1); if (w == 0) { concat_tensor = shifted; @@ -169,13 +200,32 @@ auto gather( } t = concat_tensor; } + + // Fill invalid regions with the marker. Note that when non-unit + // stride is used, it trims invalid regions, so no marking is + // necessary. + for (size_t i = 0; i < window_shape.size(); ++i) { + if (strides[i] != 1) { + continue; + } + + const auto out_extent_adj = + -window_shape[i] + 1 + pad_width[i][0] + pad_width[i][1]; + if (out_extent_adj < 0) { + std::vector indices( + t.ndimension(), at::indexing::Slice(0, at::indexing::None)); + indices[i] = at::indexing::Slice(out_extent_adj, at::indexing::None); + t.index(indices) = invalid_marker; + } + } + return t; } } // namespace // Shift an input tensor -TEST(NVFuserTest, FusionShift1_CUDA) { +TEST_F(NVFuserTest, FusionShift1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -202,7 +252,7 @@ TEST(NVFuserTest, FusionShift1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = shift(t0, {-1, 0}); @@ -219,19 +269,19 @@ TEST(NVFuserTest, FusionShift1_CUDA) { } // Shifts an intermediate tensor -TEST(NVFuserTest, FusionShift2_CUDA) { +TEST_F(NVFuserTest, FusionShift2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {-1, 0}); fusion.addOutput(tv2); // make it a little more complex - auto tv3 = add(tv0, new Double(3)); - auto tv4 = add(tv3, new Double(4)); + auto tv3 = add(tv0, IrBuilder::create(3)); + auto tv4 = add(tv3, IrBuilder::create(4)); auto tv5 = shift(tv4, {-1, 0}); auto tv6 = shift(tv4, {0, -1}); auto tv7 = shift(tv4, {1, 0}); @@ -250,21 +300,22 @@ TEST(NVFuserTest, FusionShift2_CUDA) { // t4 allocation: (t3.size[0] + 2) * (t3.size[1] + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 3 || tensor_name == 4) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { if (tensor_name == 1 && i == 1) { - TORCH_CHECK(alloc->shape().at(i)->isA()); + TORCH_CHECK(alloc->shape().at(i)->isA()); continue; } auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add); - TORCH_CHECK(def->as()->lhs()->isA()); - auto rhs = dynamic_cast(def->as()->rhs()); + dynamic_cast(alloc->shape().at(i)->definition()); + TORCH_CHECK( + def != nullptr && def->getBinaryOpType() == BinaryOpType::Add); + TORCH_CHECK(def->as()->lhs()->isA()); + auto rhs = dynamic_cast(def->as()->rhs()); TORCH_CHECK(rhs != nullptr && rhs->isConst()); int rhs_value = *rhs->value(); if (tensor_name == 1) { @@ -290,7 +341,7 @@ TEST(NVFuserTest, FusionShift2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -309,14 +360,14 @@ TEST(NVFuserTest, FusionShift2_CUDA) { testValidate(&fusion, outputs, inputs, {t2, t11}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { +TEST_F(NVFuserTest, FusionShiftRightOfCA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); fusion.addOutput(tv2); @@ -324,15 +375,15 @@ TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { tv1->setMemoryType(MemoryType::Global); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 100; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -341,16 +392,16 @@ TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) { TORCH_CHECK(t2.allclose(outputs[0])); } -TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) { +TEST_F(NVFuserTest, FusionShiftLeftOfCA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = shift(tv2, {-1, 0}); - auto tv4 = add(tv3, new Double(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); fusion.addOutput(tv4); tv0->computeAt(tv4, -1); @@ -360,13 +411,13 @@ TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) { ASSERT_ANY_THROW(fusion.printKernel()); } -TEST(NVFuserTest, FusionShiftSplit1_CUDA) { +TEST_F(NVFuserTest, FusionShiftSplit1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); auto tv3 = shift(tv1, {0, -2}); fusion.addOutput(tv2); @@ -379,35 +430,29 @@ TEST(NVFuserTest, FusionShiftSplit1_CUDA) { tv0->computeAt(tv2, -2); tv0->computeAt(tv3, -2); - // t1 allocation: (4 + 3) + // t1 allocation: 7 GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 3); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && size->value().value() == 7); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 9; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -417,23 +462,23 @@ TEST(NVFuserTest, FusionShiftSplit1_CUDA) { testValidate(&fusion, outputs, inputs, {t2, t3}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftSplit2_CUDA) { +TEST_F(NVFuserTest, FusionShiftSplit2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); auto tv3 = shift(tv2, {0, -1}); auto tv4 = shift(tv2, {0, 1}); auto tv5 = add(tv3, tv4); fusion.addOutput(tv5); - auto tv6 = add(tv0, new Double(1)); + auto tv6 = add(tv0, IrBuilder::create(1)); auto tv7 = shift(tv6, {0, 0}); - auto tv8 = add(tv7, new Double(1)); + auto tv8 = add(tv7, IrBuilder::create(1)); fusion.addOutput(tv8); int split_factor = 4; @@ -444,26 +489,20 @@ TEST(NVFuserTest, FusionShiftSplit2_CUDA) { tv0->computeAt(tv5, -2); tv0->computeAt(tv8, -2); - // t1 and t2 allocation: (4 + 2) - // t4 allocation: (4) + // t1 and t2 allocation: 6 + // t4 allocation: 4 GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && size->value().value() == 6); } else if (tensor_name == 4) { TORCH_CHECK(alloc->shape().size() == 1); - auto size = dynamic_cast(alloc->shape().at(0)); + auto size = dynamic_cast(alloc->shape().at(0)); TORCH_CHECK(size != nullptr && size->isConst()); int size_value = *size->value(); TORCH_CHECK(size_value == split_factor); @@ -471,15 +510,15 @@ TEST(NVFuserTest, FusionShiftSplit2_CUDA) { } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 9; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 2; @@ -494,14 +533,14 @@ TEST(NVFuserTest, FusionShiftSplit2_CUDA) { testValidate(&fusion, outputs, inputs, {t5, t8}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { +TEST_F(NVFuserTest, FusionShiftDoubleSplit_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {0, 1}); fusion.addOutput(tv3); @@ -518,35 +557,29 @@ TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { // t2: [i1, i2/8, 8] // t3: [i1, i2/8, 8] - // t1 and t2 allocation: (split_factor1 + 1) + // t1 and t2 allocation: (split_factor1 + 1) = 9 GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && size->value().value() == 9); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 3; @@ -555,7 +588,7 @@ TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { +TEST_F(NVFuserTest, FusionShift3ptStencil_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -576,7 +609,7 @@ TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -598,32 +631,29 @@ TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { // cache allocation: (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == cache->name()) { TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor + 2); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); + cache->doubleBuffer(); int numel_x = 99; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; @@ -631,7 +661,7 @@ TEST(NVFuserTest, FusionShift3ptStencil_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencil_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -651,7 +681,7 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -672,28 +702,22 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { // cache allocation: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == cache->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor[i] + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); + cache->doubleBuffer(); int numel_x = 99; int numel_y = 101; @@ -701,6 +725,9 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -712,7 +739,7 @@ TEST(NVFuserTest, FusionShift5ptStencil_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { +TEST_F(NVFuserTest, FusionShift9ptStencil_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -740,7 +767,7 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -763,28 +790,22 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { // cache allocation: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == cache->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor[i] + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); + cache->doubleBuffer(); int numel_x = 99; int numel_y = 101; @@ -792,6 +813,9 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -803,13 +827,13 @@ TEST(NVFuserTest, FusionShift9ptStencil_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { +TEST_F(NVFuserTest, FusionShiftSmemBlocking_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); fusion.addOutput(tv2); @@ -826,35 +850,30 @@ TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { // tv1 allocation: (split_factor + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == tv1->name()) { TORCH_CHECK(alloc->shape().size() == 1); for (int i = 0; i < 1; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == smem_block_factor && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == smem_block_factor + 1); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 100; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -864,7 +883,7 @@ TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { +TEST_F(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -881,7 +900,7 @@ TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -902,14 +921,16 @@ TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { tv_out->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); + tv0_cache->doubleBuffer(); int numel_x = 99; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3; @@ -917,7 +938,7 @@ TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -937,7 +958,7 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -965,15 +986,15 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-2)->parallelize(ParallelType::TIDy); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -985,13 +1006,13 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftMerge1_CUDA) { +TEST_F(NVFuserTest, FusionShiftMerge1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {-1, 1}); fusion.addOutput(tv2); @@ -1006,35 +1027,30 @@ TEST(NVFuserTest, FusionShiftMerge1_CUDA) { // t1 allocation: (split_factor + 1) * (split_factor + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor + 1); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1044,13 +1060,13 @@ TEST(NVFuserTest, FusionShiftMerge1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftMerge2_CUDA) { +TEST_F(NVFuserTest, FusionShiftMerge2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}); auto tv3 = shift(tv1, {-1, 1}); auto tv4 = add(tv2, tv3); @@ -1067,35 +1083,30 @@ TEST(NVFuserTest, FusionShiftMerge2_CUDA) { // t1 allocation: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1106,14 +1117,14 @@ TEST(NVFuserTest, FusionShiftMerge2_CUDA) { TORCH_CHECK(t4.allclose(outputs[0])); } -TEST(NVFuserTest, FusionShiftGlobal_CUDA) { +TEST_F(NVFuserTest, FusionShiftGlobal_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); auto tv3 = shift(tv1, {-1, 0}); auto tv4 = add(tv2, tv3); @@ -1132,17 +1143,18 @@ TEST(NVFuserTest, FusionShiftGlobal_CUDA) { // t1 allocation: (t1.size[0] + 1) * (t1.size[1] + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add); - TORCH_CHECK(def->as()->lhs()->isA()); - auto rhs = dynamic_cast(def->as()->rhs()); + dynamic_cast(alloc->shape().at(i)->definition()); + TORCH_CHECK( + def != nullptr && def->getBinaryOpType() == BinaryOpType::Add); + TORCH_CHECK(def->as()->lhs()->isA()); + auto rhs = dynamic_cast(def->as()->rhs()); TORCH_CHECK(rhs != nullptr && rhs->isConst()); int rhs_value = *rhs->value(); TORCH_CHECK(rhs_value == 1); @@ -1159,7 +1171,7 @@ TEST(NVFuserTest, FusionShiftGlobal_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1171,14 +1183,14 @@ TEST(NVFuserTest, FusionShiftGlobal_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { +TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {0, 1}); fusion.addOutput(tv3); @@ -1194,33 +1206,27 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { // t1 and t2 allocation: (split_factor1 + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { - TORCH_CHECK(alloc->shape().size() == 1); - auto def = - dynamic_cast(alloc->shape().at(0)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(0)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor1 + 1); } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 3; @@ -1229,15 +1235,15 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { +TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {1, 1}); fusion.addOutput(tv3); @@ -1271,35 +1277,30 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { // t1 and t2 allocation: (split_factor1 + 1) * (split_factor1 + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor1 + 1); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = shift(t0 + 1 + 2, {1, 1}); @@ -1307,7 +1308,7 @@ TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1327,7 +1328,7 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { tv_out = add(tv_out, tv); } - tv_out = div(tv_out, new Double(tvs.size() + 1)); + tv_out = div(tv_out, IrBuilder::create(tvs.size() + 1)); fusion.addOutput(tv_out); @@ -1361,35 +1362,30 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { // cache allocation: (split_factor1 + 2) * (split_factor2 + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == tv0_cache->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK( + size != nullptr && size->isConst() && + size->value().value() == split_factor[i] + 2); } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = t0; @@ -1401,7 +1397,7 @@ TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain1_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1416,15 +1412,15 @@ TEST(NVFuserTest, FusionShiftChain1_CUDA) { tv0->computeAt(tv2, -2); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = shift(shift(t0, {0, 1}), {0, 1}); @@ -1432,7 +1428,7 @@ TEST(NVFuserTest, FusionShiftChain1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain2_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1446,15 +1442,15 @@ TEST(NVFuserTest, FusionShiftChain2_CUDA) { tv0->computeAt(tv2, -2); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = shift(shift(t0, {0, 1}), {0, -1}); @@ -1462,13 +1458,13 @@ TEST(NVFuserTest, FusionShiftChain2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain3_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {0, 1}); auto tv3 = shift(tv2, {0, 1}); fusion.addOutput(tv3); @@ -1484,40 +1480,33 @@ TEST(NVFuserTest, FusionShiftChain3_CUDA) { // tv1: (split_factor + 2) // tv2: (split_factor + 1) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 1); for (int i = 0; i < 1; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK(size != nullptr && size->isConst()); if (tensor_name == 1) { - TORCH_CHECK(rhs_value == 2); + TORCH_CHECK(size->value().value() == split_factor + 2); } else if (tensor_name == 2) { - TORCH_CHECK(rhs_value == 1); + TORCH_CHECK(size->value().value() == split_factor + 1); } } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1528,7 +1517,7 @@ TEST(NVFuserTest, FusionShiftChain3_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftChain4_CUDA) { +TEST_F(NVFuserTest, FusionShiftChain4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1558,42 +1547,36 @@ TEST(NVFuserTest, FusionShiftChain4_CUDA) { // tv2: (split_factor + 7) * (split_factor + 7) // tv3: (split_factor + 4) * (split_factor + 4) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == 1 || tensor_name == 2) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK(size != nullptr && size->isConst()); + auto size_val = size->value().value(); if (tensor_name == 1) { - TORCH_CHECK(rhs_value == 9); + TORCH_CHECK(size_val == split_factor + 9); } else if (tensor_name == 2) { - TORCH_CHECK(rhs_value == 7); + TORCH_CHECK(size_val == split_factor + 7); } else if (tensor_name == 3) { - TORCH_CHECK(rhs_value == 4); + TORCH_CHECK(size_val == split_factor + 4); } } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = shift(t0, {1, -1}); @@ -1605,7 +1588,7 @@ TEST(NVFuserTest, FusionShiftChain4_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { +TEST_F(NVFuserTest, FusionShift5ptStencilChain_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1625,7 +1608,8 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { tv_stencil1 = add(tv_stencil1, tv); } - tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1)); + tv_stencil1 = div( + tv_stencil1, IrBuilder::create(tv_stencil1_shifts.size() + 1)); // Second stencil: Same 5pt stencil std::vector tv_stencil2_shifts; @@ -1638,7 +1622,8 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { tv_stencil2 = add(tv_stencil2, tv); } - tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1)); + tv_stencil2 = div( + tv_stencil2, IrBuilder::create(tv_stencil2_shifts.size() + 1)); auto tv_out = tv_stencil2; @@ -1682,41 +1667,34 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { // tv0_cache: (split_factor + 4) * (split_factor + 4) // tv_stencil1: (split_factor + 2) * (split_factor + 2) GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->irNodes()) { - if (auto alloc = dynamic_cast(kir_node.get())) { + for (const auto expr : gpulw.kernel()->unordered_exprs()) { + if (auto alloc = dynamic_cast(expr)) { auto tensor_name = alloc->buffer()->name(); if (tensor_name == tv0_cache->name() || tensor_name == tv_stencil1->name()) { TORCH_CHECK(alloc->shape().size() == 2); for (int i = 0; i < 2; ++i) { - auto def = - dynamic_cast(alloc->shape().at(i)->definition()); - auto lhs = dynamic_cast(def->as()->lhs()); - TORCH_CHECK(lhs != nullptr && lhs->isConst()); - int lhs_value = *lhs->value(); - auto rhs = dynamic_cast(def->as()->rhs()); - TORCH_CHECK(rhs != nullptr && rhs->isConst()); - int rhs_value = *rhs->value(); - TORCH_CHECK(lhs_value == split_factor[i]); + auto size = dynamic_cast(alloc->shape().at(i)); + TORCH_CHECK(size != nullptr && size->isConst()); if (tensor_name == tv0_cache->name()) { - TORCH_CHECK(rhs_value == 4); + TORCH_CHECK(size->value().value() == split_factor[i] + 4); } else if (tensor_name == tv_stencil1->name()) { - TORCH_CHECK(rhs_value == 2); + TORCH_CHECK(size->value().value() == split_factor[i] + 2); } } } } } - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto stencil1 = t0; @@ -1735,13 +1713,13 @@ TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) { } // Shift a reduced tensor -TEST(NVFuserTest, FusionShiftReduction1_CUDA) { +TEST_F(NVFuserTest, FusionShiftReduction1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -1758,7 +1736,7 @@ TEST(NVFuserTest, FusionShiftReduction1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1770,13 +1748,13 @@ TEST(NVFuserTest, FusionShiftReduction1_CUDA) { } // Parallelized version of FusionShiftReduction1 -TEST(NVFuserTest, FusionShiftReduction2_CUDA) { +TEST_F(NVFuserTest, FusionShiftReduction2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -1799,7 +1777,7 @@ TEST(NVFuserTest, FusionShiftReduction2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1810,13 +1788,13 @@ TEST(NVFuserTest, FusionShiftReduction2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { +TEST_F(NVFuserTest, FusionShiftRfactor1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = sum(tv1, {1}); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -1841,7 +1819,7 @@ TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -1852,7 +1830,7 @@ TEST(NVFuserTest, FusionShiftRfactor1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftBcast1_CUDA) { +TEST_F(NVFuserTest, FusionShiftBcast1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1877,7 +1855,7 @@ TEST(NVFuserTest, FusionShiftBcast1_CUDA) { std::vector inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t4 = t0.unsqueeze(-1).expand({numel_x, numel_y}) + t1; @@ -1886,7 +1864,7 @@ TEST(NVFuserTest, FusionShiftBcast1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftBcast2_CUDA) { +TEST_F(NVFuserTest, FusionShiftBcast2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1911,7 +1889,7 @@ TEST(NVFuserTest, FusionShiftBcast2_CUDA) { std::vector inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); @@ -1922,7 +1900,7 @@ TEST(NVFuserTest, FusionShiftBcast2_CUDA) { } // Combine ShiftBcast1 and ShiftBcast2 with parallelization -TEST(NVFuserTest, FusionShiftBcast3_CUDA) { +TEST_F(NVFuserTest, FusionShiftBcast3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1959,7 +1937,7 @@ TEST(NVFuserTest, FusionShiftBcast3_CUDA) { std::vector inputs = {t0, t1}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y}); @@ -1972,14 +1950,14 @@ TEST(NVFuserTest, FusionShiftBcast3_CUDA) { } // See issue #893 -TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { +TEST_F(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(2)); auto tv3 = add(tv1, tv2); auto tv4 = shift(tv3, {0, 1}); fusion.addOutput(tv4); @@ -1996,15 +1974,15 @@ TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -2016,14 +1994,14 @@ TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) { } // See issue #893. Top-level placement. -TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { +TEST_F(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv0, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(2)); auto tv3 = add(tv1, tv2); auto tv4 = shift(tv3, {1}); fusion.addOutput(tv4); @@ -2037,14 +2015,14 @@ TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-1)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -2055,14 +2033,14 @@ TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) { testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { +TEST_F(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); - auto tv2 = add(tv1, new Double(2)); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); auto tv3 = shift(tv2, {1}); fusion.addOutput(tv3); @@ -2093,7 +2071,7 @@ TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) { // along the Y dimension. The other 10 warps are used to load a 32x10 // tile, and all warps will do coalesced loads. No such optimization // is done in the fuser version. -TEST(NVFuserTest, FusionHdiff_CUDA) { +TEST_F(NVFuserTest, FusionHdiff_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2123,7 +2101,7 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { // T9 = T0 * 4 // T10 = T9 - T8 - auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors); + auto lap = sub(mul(inp, IrBuilder::create(4)), sum_of_neighbors); // T11 = shift(T10) // T12 = T11 - T10 @@ -2133,8 +2111,9 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { // T16 = T15 > 0 // T17 = T16 ? 0 : T12 auto flx_cond = - gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), new Double(0)); - auto flx0 = where(flx_cond, new Double(0), flx); + gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), + IrBuilder::create(0)); + auto flx0 = where(flx_cond, IrBuilder::create(0), flx); // T18 = shift(T10) // T19 = T18 - T10 @@ -2144,9 +2123,10 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { // T22 = T19 * T21 // T23 = T22 > 0 auto fly_cond = - gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), new Double(0)); + gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), + IrBuilder::create(0)); // T24 = T23 ? 0 : T19 - auto fly0 = where(fly_cond, new Double(0), fly); + auto fly0 = where(fly_cond, IrBuilder::create(0), fly); // T25 = shift(flx0) // T26 = T17 - T25 @@ -2233,9 +2213,6 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { } ///////////////////////////////// - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 101; int numel_y = 99; int numel_z = 10; @@ -2244,7 +2221,11 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options); at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options); std::vector inputs = {inp_at, coeff_at}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto fuser_output = fe.runFusion(inputs)[0]; + // Trim the outer rim std::vector indices{ at::indexing::Slice(0, at::indexing::None), @@ -2273,7 +2254,7 @@ TEST(NVFuserTest, FusionHdiff_CUDA) { } } -TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { +TEST_F(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2303,7 +2284,7 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { // T9 = T0 * 4 // T10 = T9 - T8 - auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors); + auto lap = sub(mul(inp, IrBuilder::create(4)), sum_of_neighbors); // T11 = shift(T10) // T12 = T11 - T10 @@ -2313,8 +2294,9 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { // T16 = T15 > 0 // T17 = T16 ? 0 : T12 auto flx_cond = - gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), new Double(0)); - auto flx0 = where(flx_cond, new Double(0), flx); + gt(mul(flx, sub(shift(inp, {0, 0, -1}, false), inp)), + IrBuilder::create(0)); + auto flx0 = where(flx_cond, IrBuilder::create(0), flx); // T18 = shift(T10) // T19 = T18 - T10 @@ -2324,9 +2306,10 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { // T22 = T19 * T21 // T23 = T22 > 0 auto fly_cond = - gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), new Double(0)); + gt(mul(fly, sub(shift(inp, {0, -1, 0}, false), inp)), + IrBuilder::create(0)); // T24 = T23 ? 0 : T19 - auto fly0 = where(fly_cond, new Double(0), fly); + auto fly0 = where(fly_cond, IrBuilder::create(0), fly); // T25 = shift(flx0) // T26 = T17 - T25 @@ -2428,9 +2411,6 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { } ///////////////////////////////// - FusionExecutor fe; - fe.compileFusion(&fusion); - const int halo_extent = 2; const int numel_x = 64 + halo_extent * 2; const int numel_y = 64 + halo_extent * 2; @@ -2440,7 +2420,11 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options); at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options); std::vector inputs = {inp_at, coeff_at}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto fuser_output = fe.runFusion(inputs)[0]; + // Trim the outer rim std::vector indices{ at::indexing::Slice(0, at::indexing::None), @@ -2470,7 +2454,7 @@ TEST(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { } // 3x3 max pooling -TEST(NVFuserTest, FusionMaxPooling_CUDA) { +TEST_F(NVFuserTest, FusionMaxPooling_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2537,9 +2521,6 @@ TEST(NVFuserTest, FusionMaxPooling_CUDA) { max_tensor->axis(0)->parallelize(ParallelType::BIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int hw = 50; const int num_channels = 20; const int pooling_window = 3; @@ -2555,6 +2536,8 @@ TEST(NVFuserTest, FusionMaxPooling_CUDA) { aten_inp = at::abs(aten_inp); std::vector inputs = {aten_inp}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = at::max_pool2d( @@ -2563,7 +2546,7 @@ TEST(NVFuserTest, FusionMaxPooling_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGatherPadding1_CUDA) { +TEST_F(NVFuserTest, FusionGather1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2586,13 +2569,13 @@ TEST(NVFuserTest, FusionGatherPadding1_CUDA) { auto ref = gather(t0, window_shape, padding_width); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); TORCH_CHECK(ref.equal(outputs[0])); } -TEST(NVFuserTest, FusionGatherPadding2_CUDA) { +TEST_F(NVFuserTest, FusionGather2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2602,7 +2585,7 @@ TEST(NVFuserTest, FusionGatherPadding2_CUDA) { auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width); @@ -2629,7 +2612,7 @@ TEST(NVFuserTest, FusionGatherPadding2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -2639,129 +2622,747 @@ TEST(NVFuserTest, FusionGatherPadding2_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConv2DStatic_CUDA) { +TEST_F(NVFuserTest, FusionGather3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Input: [C, H, W] - auto inp = makeSymbolicTensor(3); - fusion.addInput(inp); - - // Weights: [K, C, 3, 3] - auto w = makeSymbolicTensor(4); - fusion.addInput(w); - - // Gather a neighbor tile of [3, 3] with padding size of 1 for each - // side of the spatial dimensions - auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}); - // inp_tile: [C, H, W, 1, 3, 3] - - auto inp_bc = - broadcast(inp_tile, {true, false, false, false, false, false, false}); - auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); - - auto inp_times_w = mul(inp_bc, w_bc); - - // Reduce the channel and neighbor tile dimensions - auto out = sum(inp_times_w, {1, 4, 5, 6}); - - fusion.addOutput(out); + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); - //////////////////////////////////// + const std::vector window_shape = {1, 3}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; - // Cache the input and weight tensors - auto inp_cache = inp->cache_after(); + auto tv1 = gather(tv0, window_shape, padding_width); - // Blocking the spatial dimensions - const int block_w = 16; - const int block_h = 4; - // Blocking the channel dimension - const int block_c = 8; + fusion.addOutput(tv1); - out->split(2, block_h); - out->split(4, block_w); - out->reorder({{3, 4}}); - // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + const int s1 = 11; + const int s2 = 13; - out->split(1, block_c); - // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); - auto out_rf = out->rFactor({1, -3, -2, -1}); - // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] - // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); - // Create a [block_x, block_y] tile on smem - inp_cache->computeAt(out, 4); - // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] - inp_cache->setMemoryType(MemoryType::Shared); + auto ref = gather(t0, window_shape, padding_width); + TORCH_CHECK(ref.equal(outputs[0])); +} - // Move Ci forward - out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); - inp_cache->computeAt(out_rf, 5); +TEST_F(NVFuserTest, FusionGather4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); - inp_tile->computeAt(out_rf, -1); - w->computeAt(out_rf, -1); + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); - out->axis(0)->parallelize(ParallelType::BIDx); - out->axis(1)->parallelize(ParallelType::TIDz); - out->axis(4)->parallelize(ParallelType::TIDy); - out->axis(5)->parallelize(ParallelType::TIDx); + const std::vector window_shape = {3, 3}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; - scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + auto tv1 = gather(tv0, window_shape, padding_width); - FusionExecutor fe; - fe.compileFusion(&fusion); + fusion.addOutput(tv1); - const int dim_h = 99; - const int dim_w = 101; - const int dim_c = 10; - const int dim_f = 20; + const int s1 = 11; + const int s2 = 13; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); - at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); - std::vector inputs = {at_inp, at_w}; + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); - auto cg_outputs = fe.runFusion(inputs); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); - at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis - auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); - at_out = at_out.squeeze(0); // drop the N axis + auto ref = gather(t0, window_shape, padding_width); - testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); + TORCH_CHECK(ref.equal(outputs[0])); } -// Mostly the same as the static conv test, but the shape of the weights, -// 3x3 in this case, is given dynamically -TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { +TEST_F(NVFuserTest, FusionGather5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Input: [C, H, W] - auto inp = makeSymbolicTensor(3); - fusion.addInput(inp); + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); - // Weights: [K, C, S, T] - auto w = makeSymbolicTensor(4); - fusion.addInput(w); + const std::vector window_shape = {3, 3}; + const std::vector> padding_width = {{1, 0}, {0, 1}}; - auto w_h = new Int(); - fusion.addInput(w_h); - auto w_w = new Int(); - fusion.addInput(w_w); + auto tv1 = gather(tv0, window_shape, padding_width); - auto pad_h = new Int(); - fusion.addInput(pad_h); - auto pad_w = new Int(); - fusion.addInput(pad_w); + fusion.addOutput(tv1); - // Gather a neighbor tile of [w_dim_h, w_dim_w] with padding + const int s1 = 11; + const int s2 = 13; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// Conv-like pattern with no padding +TEST_F(NVFuserTest, FusionGather6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {3, 4}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; + + auto tv1 = gather(tv0, window_shape, padding_width); + + fusion.addOutput(tv1); + + // Blocking the spatial dimensions + const int block_x = 16; + const int block_y = 8; + + auto tv0_cache = tv0->cache_after(); + auto out = tv1; + auto out_cache = out->cache_before(); + + out->split(1, block_x); + out->split(0, block_y); + out->reorder({{1, 2}, {2, 1}}); + + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->axis(0)->parallelize(ParallelType::BIDy); + out->axis(1)->parallelize(ParallelType::BIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(3)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + const int s1 = 101; + const int s2 = 99; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// Conv-like pattern with irregular padding +TEST_F(NVFuserTest, FusionGather7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {3, 4}; + const std::vector> padding_width = {{0, 2}, {2, 1}}; + + auto tv1 = gather(tv0, window_shape, padding_width); + + fusion.addOutput(tv1); + + // Blocking the spatial dimensions + const int block_x = 16; + const int block_y = 8; + + auto tv0_cache = tv0->cache_after(); + auto out = tv1; + auto out_cache = out->cache_before(); + + out->split(1, block_x); + out->split(0, block_y); + out->reorder({{1, 2}, {2, 1}}); + + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->axis(0)->parallelize(ParallelType::BIDy); + out->axis(1)->parallelize(ParallelType::BIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(3)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + const int s1 = 101; + const int s2 = 99; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + size.insert(size.end(), window_shape.begin(), window_shape.end()); + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// With no padding but with striding +TEST_F(NVFuserTest, FusionGather8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {2, 3}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; + const std::vector strides = {3, 3}; + + auto tv1 = gather(tv0, window_shape, padding_width, strides); + + fusion.addOutput(tv1); + + const int s1 = 11; + const int s2 = 13; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + for (const auto i : c10::irange(size.size())) { + size[i] = ceilDiv( + size[i] - window_shape[i] + 1 + padding_width[i][0] + + padding_width[i][1], + strides[i]); + } + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width, strides); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +// Similar to Gather8 but with splitting and parallelization +TEST_F(NVFuserTest, FusionGather9_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + const std::vector window_shape = {3, 4}; + const std::vector> padding_width = {{0, 0}, {0, 0}}; + const std::vector strides = {2, 2}; + + auto tv1 = gather(tv0, window_shape, padding_width, strides); + + fusion.addOutput(tv1); + + // Blocking the spatial dimensions + const int block_x = 16; + const int block_y = 8; + + auto tv0_cache = tv0->cache_after(); + auto out = tv1; + auto out_cache = out->cache_before(); + + out->split(1, block_x); + out->split(0, block_y); + out->reorder({{1, 2}, {2, 1}}); + + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->axis(0)->parallelize(ParallelType::BIDy); + out->axis(1)->parallelize(ParallelType::BIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(3)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + const int s1 = 101; + const int s2 = 99; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector size({s1, s2}); + at::Tensor t0 = at::randn(size, options); + for (const auto i : c10::irange(size.size())) { + size[i] = ceilDiv( + size[i] - window_shape[i] + 1 + padding_width[i][0] + + padding_width[i][1], + strides[i]); + } + size.insert(size.end(), window_shape.begin(), window_shape.end()); + // Use a pre-allocated output tensor filled with 1 so that invalid + // writes to outside valid ranges can be detected + at::Tensor output = at::ones(size, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}, {output}); + + auto ref = gather(t0, window_shape, padding_width, strides); + + TORCH_CHECK(ref.equal(outputs[0])); +} + +TEST_F(NVFuserTest, FusionConv2D_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 3, 3] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [3, 3] with padding size of 1 for each + // side of the spatial dimensions + auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}}); + // inp_tile: [C, H, W, 1, 3, 3] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); + + //////////////////////////////////// + + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; + + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); + + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv2DNoPadding_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 3, 3] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [3, 3] with no padding + auto inp_tile = + gather(inp, {1, 3, 3}, {{0, 0}, {0, 0}, {0, 0}}, {1, 1, 1}, true); + // inp_tile: [C, H-2, W-2, 1, 3, 3] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); + + //////////////////////////////////// + + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; + + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); + + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + std::vector stride = {1, 1}; + std::vector padding = {0, 0}; + auto at_out = at::conv2d(at_inp, at_w, {}, stride, padding); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv2DNoPaddingStrided_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 3, 3] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [2, 2] with no padding and strides of + // [2, 2] + auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 0}, {0, 0}}, {1, 2, 2}); + // inp_tile: [C, H/2, W/2, 1, 2, 2] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); + + //////////////////////////////////// + + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; + + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); + + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + std::vector stride = {2, 2}; + std::vector padding = {0, 0}; + auto at_out = at::conv2d(at_inp, at_w, {}, stride, padding); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +// 5x5 followed by 3x3 +TEST_F(NVFuserTest, FusionConv2DChain_CUDA) { + const int dim_w1_h = 5; + const int dim_w1_w = 5; + const int dim_pad1_h = (dim_w1_h - 1) / 2; + const int dim_pad1_w = (dim_w1_w - 1) / 2; + const int dim_w2_h = 3; + const int dim_w2_w = 3; + const int dim_pad2_h = (dim_w2_h - 1) / 2; + const int dim_pad2_w = (dim_w2_w - 1) / 2; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [K1, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K2, K1, S1, T1] + auto w1 = makeSymbolicTensor(4); + fusion.addInput(w1); + + // Weights: [K3, K2, S2, T2] + auto w2 = makeSymbolicTensor(4); + fusion.addInput(w2); + + // Gather a neighbor tile of [w1_h, w1_w] with padding auto inp_tile = gather( inp, - {new Int(1), w_h, w_w}, - {{new Int(0), new Int(0)}, {pad_h, pad_h}, {pad_w, pad_w}}); - // inp_tile: [C, 1, H - w_h + 1, W - w_w + 1, w_h, w_w] + {1, dim_w1_h, dim_w1_w}, + {{0, 0}, {dim_pad1_h, dim_pad1_h}, {dim_pad1_w, dim_pad1_w}}); + // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false}); + + auto inp_times_w1 = mul(inp_bc, w1_bc); + + // Reduce the channel and neighbor tile dimensions + auto out1 = sum(inp_times_w1, {1, 4, 5, 6}); + + // Second conv + auto out1_tile = gather( + out1, + {1, dim_w2_h, dim_w2_w}, + {{0, 0}, {dim_pad2_h, dim_pad2_h}, {dim_pad2_w, dim_pad2_w}}); + + auto out1_bc = + broadcast(out1_tile, {true, false, false, false, false, false, false}); + auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false}); + + auto out1_times_w2 = mul(out1_bc, w2_bc); + + auto out2 = sum(out1_times_w2, {1, 4, 5, 6}); + + fusion.addOutput(out2); + + //////////////////////////////////// + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); + + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + + out2->split(2, block_h); + out2->split(4, block_w); + out2->reorder({{3, 4}}); + // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3] + + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out2, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); + + // Move Ci forward + out1->reorder({{5, 3}, {3, 4}, {4, 5}}); + out1->setMemoryType(MemoryType::Shared); + + inp_cache->computeAt(out1, 4); + + inp_tile->computeAt(out1, -1); + w1->computeAt(out1, -1); + + out1_tile->computeAt(out2, -1); + w2->computeAt(out2, -1); + + out2->axis(0)->parallelize(ParallelType::BIDx); + out2->axis(4)->parallelize(ParallelType::TIDy); + out2->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_k1 = 3; + const int dim_k2 = 5; + const int dim_k3 = 7; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options); + at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options); + at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options); + std::vector inputs = {at_inp, at_w1, at_w2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2); + auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1); + at_out2 = at_out2.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 2, 2] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [2, 2] with padding size of 1 only for + // the right side of the spatial dimensions. The left padding is + // zero so that the output axis stays the same. + auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}}); + // inp_tile: [C, H, W, 1, 2, 2] auto inp_bc = broadcast(inp_tile, {true, false, false, false, false, false, false}); @@ -2775,6 +3376,7 @@ TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { fusion.addOutput(out); //////////////////////////////////// + // Cache the input and weight tensors auto inp_cache = inp->cache_after(); @@ -2787,13 +3389,13 @@ TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { out->split(2, block_h); out->split(4, block_w); out->reorder({{3, 4}}); - // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3] + // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2] out->split(1, block_c); - // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] auto out_rf = out->rFactor({1, -3, -2, -1}); - // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3] + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] // out_rf: [K, Ci, Ho, Wo, Hi, Wi] // Create a [block_x, block_y] tile on smem @@ -2815,185 +3417,226 @@ TEST(NVFuserTest, FusionConv2DDynamic_CUDA) { scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 99; const int dim_w = 101; const int dim_c = 10; const int dim_f = 20; - const int dim_w_h = 3; - const int dim_w_w = 3; - const int dim_pad_h = (dim_w_h - 1) / 2; - const int dim_pad_w = (dim_w_w - 1) / 2; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); - at::Tensor at_w = at::randn({dim_f, dim_c, dim_w_h, dim_w_w}, options); - std::vector inputs = { - at_inp, at_w, dim_w_h, dim_w_w, dim_pad_h, dim_pad_w}; + at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); + std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); at_out = at_out.squeeze(0); // drop the N axis + // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas + // the fuser output has dim_h*dim_w. Drop the first elements to make + // it match with the fuser output. + std::vector indices{ + at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(1, at::indexing::None), + at::indexing::Slice(1, at::indexing::None)}; + at_out = at_out.index(indices); testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -// 5x5 followed by 3x3 -TEST(NVFuserTest, FusionConv2DDynamicChain_CUDA) { +TEST_F(NVFuserTest, FusionConv4x4Pad1x1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); - // Input: [K1, H, W] + // Input: [C, H, W] auto inp = makeSymbolicTensor(3); fusion.addInput(inp); - // Weights: [K2, K1, S1, T1] - auto w1 = makeSymbolicTensor(4); - fusion.addInput(w1); + // Weights: [K, C, 4, 4] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); - // Weights: [K3, K2, S2, T2] - auto w2 = makeSymbolicTensor(4); - fusion.addInput(w2); + // Gather a neighbor tile of [4, 4] with padding size of 1 for both + // sides of the spatial dimensions. The resulting extent is + // decreased by one. + auto inp_tile = + gather(inp, {1, 4, 4}, {{0, 0}, {1, 1}, {1, 1}}, {1, 1, 1}, true); + // inp_tile: [C, H-1, W-1, 1, 4, 4] - auto w1_h = new Int(); - fusion.addInput(w1_h); - auto w1_w = new Int(); - fusion.addInput(w1_w); + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); - auto w2_h = new Int(); - fusion.addInput(w2_h); - auto w2_w = new Int(); - fusion.addInput(w2_w); + auto inp_times_w = mul(inp_bc, w_bc); - auto pad_h1 = new Int(); - fusion.addInput(pad_h1); - auto pad_w1 = new Int(); - fusion.addInput(pad_w1); + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); - auto pad_h2 = new Int(); - fusion.addInput(pad_h2); - auto pad_w2 = new Int(); - fusion.addInput(pad_w2); + fusion.addOutput(out); - // Gather a neighbor tile of [w1_h, w1_w] with padding - auto inp_tile = gather( - inp, - {new Int(1), w1_h, w1_w}, - {{new Int(0), new Int(0)}, {pad_h1, pad_h1}, {pad_w1, pad_w1}}); - // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w] + //////////////////////////////////// - auto inp_bc = - broadcast(inp_tile, {true, false, false, false, false, false, false}); - auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false}); + // Cache the input and weight tensors + auto inp_cache = inp->cache_after(); - auto inp_times_w1 = mul(inp_bc, w1_bc); + // Blocking the spatial dimensions + const int block_w = 16; + const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; - // Reduce the channel and neighbor tile dimensions - auto out1 = sum(inp_times_w1, {1, 4, 5, 6}); + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 4, 4] - // Second conv - auto out1_tile = gather( - out1, - {new Int(1), w2_h, w2_w}, - {{new Int(0), new Int(0)}, {pad_h2, pad_h2}, {pad_w2, pad_w2}}); + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 4, 4] - auto out1_bc = - broadcast(out1_tile, {true, false, false, false, false, false, false}); - auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false}); + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 4, 4] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] - auto out1_times_w2 = mul(out1_bc, w2_bc); + // Create a [block_x, block_y] tile on smem + inp_cache->computeAt(out, 4); + // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + inp_cache->setMemoryType(MemoryType::Shared); - auto out2 = sum(out1_times_w2, {1, 4, 5, 6}); + // Move Ci forward + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); - fusion.addOutput(out2); + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); + + const int dim_h = 99; + const int dim_w = 101; + const int dim_c = 10; + const int dim_f = 20; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 4, 4}, options); + std::vector inputs = {at_inp, at_w}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis + auto at_out = + at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 1, 1); + at_out = at_out.squeeze(0); // drop the N axis + + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConv4x5Pad1x2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Input: [C, H, W] + auto inp = makeSymbolicTensor(3); + fusion.addInput(inp); + + // Weights: [K, C, 4, 4] + auto w = makeSymbolicTensor(4); + fusion.addInput(w); + + // Gather a neighbor tile of [4, 5] with padding size of 1 and 2 for + // each side of the spatial dimensions. + auto inp_tile = + gather(inp, {1, 4, 5}, {{0, 0}, {1, 1}, {2, 2}}, {1, 1, 1}, true); + // inp_tile: [C, H-1, W, 1, 4, 5] + + auto inp_bc = + broadcast(inp_tile, {true, false, false, false, false, false, false}); + auto w_bc = broadcast(w, {false, false, true, true, true, false, false}); + + auto inp_times_w = mul(inp_bc, w_bc); + + // Reduce the channel and neighbor tile dimensions + auto out = sum(inp_times_w, {1, 4, 5, 6}); + + fusion.addOutput(out); //////////////////////////////////// + // Cache the input and weight tensors auto inp_cache = inp->cache_after(); // Blocking the spatial dimensions const int block_w = 16; const int block_h = 4; + // Blocking the channel dimension + const int block_c = 8; - out2->split(2, block_h); - out2->split(4, block_w); - out2->reorder({{3, 4}}); - // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3] + out->split(2, block_h); + out->split(4, block_w); + out->reorder({{3, 4}}); + // out: [K, C, Ho, Wo, Hi, Wi, 1, 4, 5] + + out->split(1, block_c); + // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 4, 5] + + auto out_rf = out->rFactor({1, -3, -2, -1}); + // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 4, 5] + // out_rf: [K, Ci, Ho, Wo, Hi, Wi] // Create a [block_x, block_y] tile on smem - inp_cache->computeAt(out2, 4); + inp_cache->computeAt(out, 4); // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] inp_cache->setMemoryType(MemoryType::Shared); // Move Ci forward - out1->reorder({{5, 3}, {3, 4}, {4, 5}}); - out1->setMemoryType(MemoryType::Shared); - - inp_cache->computeAt(out1, 4); - - inp_tile->computeAt(out1, -1); - w1->computeAt(out1, -1); - - out1_tile->computeAt(out2, -1); - w2->computeAt(out2, -1); + out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); + inp_cache->computeAt(out_rf, 5); - out2->axis(0)->parallelize(ParallelType::BIDx); - out2->axis(4)->parallelize(ParallelType::TIDy); - out2->axis(5)->parallelize(ParallelType::TIDx); + inp_tile->computeAt(out_rf, -1); + w->computeAt(out_rf, -1); - scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1}); + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(1)->parallelize(ParallelType::TIDz); + out->axis(4)->parallelize(ParallelType::TIDy); + out->axis(5)->parallelize(ParallelType::TIDx); - FusionExecutor fe; - fe.compileFusion(&fusion); + scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); const int dim_h = 99; const int dim_w = 101; - const int dim_k1 = 3; - const int dim_k2 = 5; - const int dim_k3 = 7; - const int dim_w1_h = 5; - const int dim_w1_w = 5; - const int dim_pad1_h = (dim_w1_h - 1) / 2; - const int dim_pad1_w = (dim_w1_w - 1) / 2; - const int dim_w2_h = 3; - const int dim_w2_w = 3; - const int dim_pad2_h = (dim_w2_h - 1) / 2; - const int dim_pad2_w = (dim_w2_w - 1) / 2; + const int dim_c = 10; + const int dim_f = 20; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); - at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options); - at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options); - at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options); - std::vector inputs = { - at_inp, - at_w1, - at_w2, - dim_w1_h, - dim_w1_w, - dim_w2_h, - dim_w2_w, - dim_pad1_h, - dim_pad1_w, - dim_pad2_h, - dim_pad2_w}; + at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 4, 5}, options); + std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis - auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2); - auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1); - at_out2 = at_out2.squeeze(0); // drop the N axis + auto at_out = + at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 1, {1, 2}); + at_out = at_out.squeeze(0); // drop the N axis - testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { +TEST_F(NVFuserTest, FusionConv4x4Pad1x1Stride4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3001,15 +3644,14 @@ TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { auto inp = makeSymbolicTensor(3); fusion.addInput(inp); - // Weights: [K, C, 2, 2] + // Weights: [K, C, 3, 3] auto w = makeSymbolicTensor(4); fusion.addInput(w); - // Gather a neighbor tile of [2, 2] with padding size of 1 only for - // the right side of the spatial dimensions. The left padding is - // zero so that the output axis stays the same. - auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}}); - // inp_tile: [C, H, W, 1, 2, 2] + // Gather a neighbor tile of [4, 4] with padding size of 1 for both + // sides of the spatial dimensions. Set the stride width as 4. + auto inp_tile = gather(inp, {1, 4, 4}, {{0, 0}, {1, 1}, {1, 1}}, {1, 4, 4}); + // inp_tile: [C, H/4, s4, W/4, s4, 1, 4, 4] auto inp_bc = broadcast(inp_tile, {true, false, false, false, false, false, false}); @@ -3030,43 +3672,49 @@ TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { // Blocking the spatial dimensions const int block_w = 16; const int block_h = 4; - // Blocking the channel dimension - const int block_c = 8; + const int block_c = 2; + // [K, C, H/s, W/s, 1, 4, 4] out->split(2, block_h); + // [K, C, H/s/block_h, block_h, W/s, 1, 4, 4] out->split(4, block_w); + // [K, C, H/s/block_h, block_h, W/s/block_w, block_w, 1, 4, 4] out->reorder({{3, 4}}); - // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2] - + // [K, C, H/s/block_h, W/s/block_w, block_h, block_w, 1, 4, 4] out->split(1, block_c); - // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] + // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, block_h, block_w, 1, 4, + // 4] + out->split(4, 1); + // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1, + // 4, 4] auto out_rf = out->rFactor({1, -3, -2, -1}); - // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2] - // out_rf: [K, Ci, Ho, Wo, Hi, Wi] + // [K, C/block_c, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, 1, + // 4, 4] - // Create a [block_x, block_y] tile on smem - inp_cache->computeAt(out, 4); - // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi] + // out: [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w] + + inp_cache->computeAt(out, 5); inp_cache->setMemoryType(MemoryType::Shared); + // [K, block_c, H/s/block_h, W/s/block_w, 1, block_h, block_w, C/block_c, 1, + // 4, 4] - // Move Ci forward - out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}}); - inp_cache->computeAt(out_rf, 5); + // Move C/block_c before block_h/2 and share the domain from + // inp_cache to out_rf + out_rf->reorder({{7, 5}, {5, 6}, {6, 7}}); + inp_cache->computeAt(out_rf, 6); inp_tile->computeAt(out_rf, -1); w->computeAt(out_rf, -1); out->axis(0)->parallelize(ParallelType::BIDx); out->axis(1)->parallelize(ParallelType::TIDz); - out->axis(4)->parallelize(ParallelType::TIDy); - out->axis(5)->parallelize(ParallelType::TIDx); + out->axis(4)->parallelize(ParallelType::Unswitch); + out->axis(5)->parallelize(ParallelType::TIDy); + out->axis(6)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 99; const int dim_w = 101; const int dim_c = 10; @@ -3075,28 +3723,23 @@ TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(0); at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options); - at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options); + at::Tensor at_w = at::randn({dim_f, dim_c, 4, 4}, options); std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis - auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1); + auto at_out = + at::conv2d(at_inp.to(at::kDouble), at_w.to(at::kDouble), {}, 4, {1, 1}); at_out = at_out.squeeze(0); // drop the N axis - // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas - // the fuser output has dim_h*dim_w. Drop the first elements to make - // it match with the fuser output. - std::vector indices{ - at::indexing::Slice(0, at::indexing::None), - at::indexing::Slice(1, at::indexing::None), - at::indexing::Slice(1, at::indexing::None)}; - at_out = at_out.index(indices); testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } // POC implementation of im2col for 3-by-3 kernels -TEST(NVFuserTest, FusionIm2Col_CUDA) { +TEST_F(NVFuserTest, FusionIm2Col_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3147,9 +3790,6 @@ TEST(NVFuserTest, FusionIm2Col_CUDA) { scheduler_utils::parallelizeAllLike(out, {inp_cache, inp_tile}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 31; const int dim_w = 33; const int dim_c = 5; @@ -3160,6 +3800,8 @@ TEST(NVFuserTest, FusionIm2Col_CUDA) { at::Tensor at_inp = at::randn({dim_n, dim_c, dim_h, dim_w}, options); std::vector inputs = {at_inp}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); auto at_out = at::im2col(at_inp, {3, 3}, {1, 1}, {1, 1}, {1, 1}); @@ -3171,14 +3813,14 @@ TEST(NVFuserTest, FusionIm2Col_CUDA) { testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPadding1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3201,9 +3843,6 @@ TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { tv5->axis(-2)->parallelize(ParallelType::TIDy); scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3211,6 +3850,9 @@ TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -3226,14 +3868,14 @@ TEST(NVFuserTest, FusionShiftNoPadding1_CUDA) { } // Split and merge -TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPadding2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3256,9 +3898,6 @@ TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { tv5->axis(-1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3266,6 +3905,9 @@ TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -3281,14 +3923,14 @@ TEST(NVFuserTest, FusionShiftNoPadding2_CUDA) { } // Split and merge, then welford -TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPadding3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3316,9 +3958,6 @@ TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { tv_avg->axis(-1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv_avg, ir_utils::allTvs(&fusion)); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3327,7 +3966,11 @@ TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); + outputs[1] /= (numel_x - 2) * (numel_y - 2); auto t1 = t0 + 1; @@ -3346,13 +3989,13 @@ TEST(NVFuserTest, FusionShiftNoPadding3_CUDA) { } // Shift indexing and predication with contiguous merge -TEST(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, true); auto tv3 = shift(tv1, {-1, 1}, false); auto tv4 = add(tv2, tv3); @@ -3366,15 +4009,15 @@ TEST(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { tv2->setMemoryType(MemoryType::Global); tv3->setMemoryType(MemoryType::Global); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 9; int numel_y = 11; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3392,14 +4035,14 @@ TEST(NVFuserTest, FusionShiftNoPaddingContigMerge_CUDA) { testValidate(&fusion, {fuser_out}, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = shift(tv2, {1, -1}, false); auto tv4 = sum(tv3, {0, 1}); @@ -3422,9 +4065,6 @@ TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3}); - FusionExecutor fe; - fe.compileFusion(&fusion); - int numel_x = 99; int numel_y = 101; @@ -3433,6 +4073,9 @@ TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -3447,14 +4090,14 @@ TEST(NVFuserTest, FusionShiftNoPaddingChain_CUDA) { } // Rfactor is not allowed with partial domains -TEST(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { +TEST_F(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1, -1}, false); auto tv3 = sum(tv2, {0, 1}); fusion.addOutput(tv3); @@ -3466,7 +4109,61 @@ TEST(NVFuserTest, FusionShiftNoPaddingRfactor_CUDA) { ASSERT_ANY_THROW(tv3->rFactor({-2})); } -TEST(NVFuserTest, FusionPartialSplit1_CUDA) { +TEST_F(NVFuserTest, FusionShiftPadding1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = shift(tv1, {2, -2}, {1, 1}); + auto tv3 = shift(tv1, {-3, 2}, {2, 2}); + auto tv4 = add(tv2, tv3); + auto tv5 = sum(tv4, {0, 1}); + + fusion.addOutput(tv5); + + tv1->setMemoryType(MemoryType::Shared); + + tv5->split(0, 4); + tv5->split(-1, 8); + tv5->reorder({{1, 2}}); + + TransformPropagator::from(tv5); + + tv2->computeAt(tv5, -1); + tv3->computeAt(tv5, -1); + + tv5->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-2)->parallelize(ParallelType::TIDy); + scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion)); + + int numel_x = 99; + int numel_y = 101; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto t1 = t0 + 1; + auto t2 = shift(t1, {2, -2}); + auto t3 = shift(t1, {-3, 2}); + auto t4 = t2 + t3; + std::vector indices{ + at::indexing::Slice(1, -1), at::indexing::Slice(0, -1)}; + t4 = t4.index(indices); + auto ref = t4.sum(at::ArrayRef{0, 1}); + + testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPartialSplit1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3474,7 +4171,7 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { // [I] fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(0)); + auto tv1 = add(tv0, IrBuilder::create(0)); // [I] auto tv2 = shift(tv1, {1}, false); // [1:I] @@ -3504,9 +4201,6 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - // gridDim.x is ceilDiv(numel_x - 2, 8), not ceilDiv(numel_x, 8), // so it's going to be just 2 rather than 3. const int numel_x = 18; @@ -3527,6 +4221,9 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{at::indexing::Slice(1, -1)}; @@ -3538,21 +4235,21 @@ TEST(NVFuserTest, FusionPartialSplit1_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPartialSplit2_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(0)); + auto tv1 = add(tv0, IrBuilder::create(0)); auto tv2 = shift(tv1, {1}, false); auto tv3 = shift(tv1, {-1}, false); auto tv4 = add(tv2, tv3); fusion.addOutput(tv4); - auto tv5 = add(tv1, new Double(1)); - auto tv6 = add(tv5, new Double(1)); + auto tv5 = add(tv1, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); fusion.addOutput(tv6); tv4->split(0, 4, true, true); @@ -3568,14 +4265,14 @@ TEST(NVFuserTest, FusionPartialSplit2_CUDA) { } // 2D version of PartialSplit1 -TEST(NVFuserTest, FusionPartialSplit3_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(0)); + auto tv1 = add(tv0, IrBuilder::create(0)); auto tv2 = shift(tv1, {1, 2}, false); auto tv3 = shift(tv1, {-2, -1}, false); auto tv4 = add(tv2, tv3); @@ -3595,9 +4292,6 @@ TEST(NVFuserTest, FusionPartialSplit3_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int numel_x = 32 + 3; const int numel_y = 32 + 3; @@ -3606,6 +4300,9 @@ TEST(NVFuserTest, FusionPartialSplit3_CUDA) { at::manual_seed(0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3620,7 +4317,7 @@ TEST(NVFuserTest, FusionPartialSplit3_CUDA) { // Almost same fusion with Shift5ptStencilChain but non-padded shift // and partial split. -TEST(NVFuserTest, FusionPartialSplit4_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3641,7 +4338,8 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { tv_stencil1 = add(tv_stencil1, tv); } - tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1)); + tv_stencil1 = div( + tv_stencil1, IrBuilder::create(tv_stencil1_shifts.size() + 1)); // Second stencil: Same 5pt stencil std::vector tv_stencil2_shifts; @@ -3654,7 +4352,8 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { tv_stencil2 = add(tv_stencil2, tv); } - tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1)); + tv_stencil2 = div( + tv_stencil2, IrBuilder::create(tv_stencil2_shifts.size() + 1)); auto tv_out = tv_stencil2; @@ -3696,9 +4395,6 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { tv0_cache->setMemoryType(MemoryType::Shared); tv_stencil1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - // Input matrix size is 68x68, and the output is 64x64. Both // gridDim.x and gridim.y should be ceilDiv(numel - 4, // split_factor), which is 4. If full split is used, the grid @@ -3709,6 +4405,9 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3731,7 +4430,7 @@ TEST(NVFuserTest, FusionPartialSplit4_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPartialSplit5_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3743,7 +4442,7 @@ TEST(NVFuserTest, FusionPartialSplit5_CUDA) { fusion.addInput(tv0); auto tv1 = shift(tv0, {0, 1}, false); - auto tv2 = add(tv1, new Double(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); fusion.addOutput(tv2); @@ -3760,12 +4459,12 @@ TEST(NVFuserTest, FusionPartialSplit5_CUDA) { tv1->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x, numel_y}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3779,7 +4478,7 @@ TEST(NVFuserTest, FusionPartialSplit5_CUDA) { testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionPartialSplit6_CUDA) { +TEST_F(NVFuserTest, FusionPartialSplit6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3788,9 +4487,9 @@ TEST(NVFuserTest, FusionPartialSplit6_CUDA) { auto tv0 = makeConcreteTensor({numel_x}); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {1}, false); - auto tv3 = add(tv2, new Double(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); fusion.addOutput(tv3); @@ -3803,12 +4502,12 @@ TEST(NVFuserTest, FusionPartialSplit6_CUDA) { tv1->setMemoryType(MemoryType::Shared); tv2->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); std::vector indices{ @@ -3821,7 +4520,7 @@ TEST(NVFuserTest, FusionPartialSplit6_CUDA) { testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { +TEST_F(NVFuserTest, FusionShiftUnswitch1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3840,7 +4539,7 @@ TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { auto tv4 = shift(tv0, {-2, -2}); fusion.addOutput(tv4); - auto tv5 = add(tv0, new Double(1)); + auto tv5 = add(tv0, IrBuilder::create(1)); auto tv6 = shift(tv5, {0, -1}); fusion.addOutput(tv6); @@ -3862,7 +4561,7 @@ TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = shift(t0, {-1, 0}); @@ -3881,27 +4580,22 @@ TEST(NVFuserTest, FusionShiftUnswitch1_CUDA) { TORCH_CHECK(t6.equal(outputs[4])); } -TEST(NVFuserTest, FusionGatherUnswitch1_CUDA) { +TEST_F(NVFuserTest, FusionGatherUnswitch1_CUDA) { + const int tv1_gather = 3; + const int tv1_gather_pad = 1; + const int tv2_gather = 5; + const int tv2_gather_pad = 2; + Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1_gather_param = new Int(); - fusion.addInput(tv1_gather_param); - auto tv1_gather_pad_param = new Int(); - fusion.addInput(tv1_gather_pad_param); - auto tv1 = gather( - tv0, {tv1_gather_param}, {{tv1_gather_pad_param, tv1_gather_pad_param}}); + auto tv1 = gather(tv0, {tv1_gather}, {{tv1_gather_pad, tv1_gather_pad}}); fusion.addOutput(tv1); - auto tv2_gather_param = new Int(); - fusion.addInput(tv2_gather_param); - auto tv2_gather_pad_param = new Int(); - fusion.addInput(tv2_gather_pad_param); - auto tv2 = gather( - tv0, {tv2_gather_param}, {{tv2_gather_pad_param, tv2_gather_pad_param}}); + auto tv2 = gather(tv0, {tv2_gather}, {{tv2_gather_pad, tv2_gather_pad}}); fusion.addOutput(tv2); // Static gather @@ -3923,18 +4617,13 @@ TEST(NVFuserTest, FusionGatherUnswitch1_CUDA) { tv4->axis(1)->parallelize(ParallelType::TIDx); const int numel_x = 100; - const int tv1_gather = 3; - const int tv1_gather_pad = 1; - const int tv2_gather = 5; - const int tv2_gather_pad = 2; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({numel_x}, options); - std::vector inputs = { - t0, tv1_gather, tv1_gather_pad, tv2_gather, tv2_gather_pad}; + std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = gather(t0, {tv1_gather}, {{tv1_gather_pad, tv1_gather_pad}}); @@ -3950,7 +4639,7 @@ TEST(NVFuserTest, FusionGatherUnswitch1_CUDA) { TORCH_CHECK(t4.equal(outputs[3])); } -TEST(NVFuserTest, FusionGatherStrided1_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -3973,7 +4662,7 @@ TEST(NVFuserTest, FusionGatherStrided1_CUDA) { at::Tensor t0 = at::randn({s1, s2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); // tv1 has a stride dimension, so its number of dimensions should be @@ -4013,7 +4702,7 @@ TEST(NVFuserTest, FusionGatherStrided1_CUDA) { } // Split strided domain -TEST(NVFuserTest, FusionGatherStrided2_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4024,7 +4713,7 @@ TEST(NVFuserTest, FusionGatherStrided2_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4054,7 +4743,7 @@ TEST(NVFuserTest, FusionGatherStrided2_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4065,7 +4754,7 @@ TEST(NVFuserTest, FusionGatherStrided2_CUDA) { } // Outer split -TEST(NVFuserTest, FusionGatherStrided3_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4076,7 +4765,7 @@ TEST(NVFuserTest, FusionGatherStrided3_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4102,7 +4791,7 @@ TEST(NVFuserTest, FusionGatherStrided3_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4112,7 +4801,7 @@ TEST(NVFuserTest, FusionGatherStrided3_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionGatherStrided4_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4123,7 +4812,7 @@ TEST(NVFuserTest, FusionGatherStrided4_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); // Test propagation of split from one gather output to another auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4147,7 +4836,7 @@ TEST(NVFuserTest, FusionGatherStrided4_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4158,7 +4847,7 @@ TEST(NVFuserTest, FusionGatherStrided4_CUDA) { } // Same as GatherStrided1 but with stride != window -TEST(NVFuserTest, FusionGatherStrided5_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided5_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4181,7 +4870,7 @@ TEST(NVFuserTest, FusionGatherStrided5_CUDA) { at::Tensor t0 = at::randn({s1, s2}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto outputs = fe.runFusion({t0}); auto ref = gather(t0, window_shape, padding_width, strides); @@ -4190,7 +4879,7 @@ TEST(NVFuserTest, FusionGatherStrided5_CUDA) { } // Same as GatherStrided2 but with stride != window -TEST(NVFuserTest, FusionGatherStrided6_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided6_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4201,7 +4890,7 @@ TEST(NVFuserTest, FusionGatherStrided6_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4231,7 +4920,7 @@ TEST(NVFuserTest, FusionGatherStrided6_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4242,7 +4931,7 @@ TEST(NVFuserTest, FusionGatherStrided6_CUDA) { } // Same as GatherStrided4 but different strides -TEST(NVFuserTest, FusionGatherStrided7_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided7_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4252,7 +4941,7 @@ TEST(NVFuserTest, FusionGatherStrided7_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); // Use different strides auto tv2 = gather(tv1, window_shape, padding_width, {3}); @@ -4271,7 +4960,7 @@ TEST(NVFuserTest, FusionGatherStrided7_CUDA) { } // Same as GatherStrided2 but with unswitch -TEST(NVFuserTest, FusionGatherStrided8_CUDA) { +TEST_F(NVFuserTest, FusionGatherStrided8_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4282,7 +4971,7 @@ TEST(NVFuserTest, FusionGatherStrided8_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); @@ -4316,7 +5005,7 @@ TEST(NVFuserTest, FusionGatherStrided8_CUDA) { std::vector inputs = {t0}; FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto t1 = t0 + 1; @@ -4327,7 +5016,7 @@ TEST(NVFuserTest, FusionGatherStrided8_CUDA) { } // Chained strided gather. Not supported yet. -TEST(NVFuserTest, FusionGatherStridedChain_CUDA) { +TEST_F(NVFuserTest, FusionGatherStridedChain_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4339,7 +5028,7 @@ TEST(NVFuserTest, FusionGatherStridedChain_CUDA) { auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = gather(tv1, window_shape, padding_width, strides); // Reduce gathered window @@ -4356,10 +5045,7 @@ TEST(NVFuserTest, FusionGatherStridedChain_CUDA) { ASSERT_ANY_THROW(GpuLower gpulw(&fusion)); } -TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionMaxPoolingStrided_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4379,7 +5065,7 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { auto max_tensor = reductionOp( BinaryOpType::Max, {-3, -2, -1}, - new Double(std::numeric_limits::lowest()), + IrBuilder::create(std::numeric_limits::lowest()), inp_tile); fusion.addOutput(max_tensor); @@ -4410,9 +5096,6 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { inp_cache->setMemoryType(MemoryType::Shared); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int hw = 50; const int num_channels = 20; const int pooling_window = 3; @@ -4428,6 +5111,8 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { aten_inp = at::abs(aten_inp); std::vector inputs = {aten_inp}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto outputs = fe.runFusion(inputs); auto ref = at::max_pool2d( @@ -4436,10 +5121,7 @@ TEST(NVFuserTest, FusionMaxPoolingStrided_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } +TEST_F(NVFuserTest, FusionConv2DStaticStrided_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4518,9 +5200,6 @@ TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf}); - FusionExecutor fe; - fe.compileFusion(&fusion); - const int dim_h = 99; const int dim_w = 101; const int dim_c = 10; @@ -4532,6 +5211,8 @@ TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options); std::vector inputs = {at_inp, at_w}; + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); auto cg_outputs = fe.runFusion(inputs); at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis @@ -4541,14 +5222,14 @@ TEST(NVFuserTest, FusionConv2DStaticStrided_CUDA) { testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); - auto tv1 = add(tv0, new Double(1)); + auto tv1 = add(tv0, IrBuilder::create(1)); auto tv2 = shift(tv1, {-1}); fusion.addOutput(tv2); @@ -4564,7 +5245,7 @@ TEST(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { at::Tensor t0 = at::randn({24}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto ref = shift((t0 + 1), {-1}); @@ -4572,7 +5253,7 @@ TEST(NVFuserTest, FusionNonDivisibleHalo1_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); } -TEST(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { +TEST_F(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -4621,7 +5302,7 @@ TEST(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { at::Tensor t0 = at::randn({111, 222}, options); FusionExecutor fe; - fe.compileFusion(&fusion); + fe.compileFusion(&fusion, {t0}); auto cg_outputs = fe.runFusion({t0}); auto t1 = gather(t0, {3, 3}, {{1, 1}, {1, 1}}); @@ -4632,6 +5313,59 @@ TEST(NVFuserTest, FusionNonDivisibleHalo2_CUDA) { testValidate(&fusion, cg_outputs, {t0}, {t4}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionGather9ptStencilDoubleBuffering_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = gather(tv0, {3, 3}, {{1, 1}, {1, 1}}); + auto tv2 = sum(tv1, {-2, -1}); + auto tv3 = div(tv2, IrBuilder::create(9)); + + auto out = tv3; + + fusion.addOutput(out); + + auto tv0_cache = tv0->cache_after(); + + tv0_cache->setMemoryType(MemoryType::Shared); + + out->split(-2, 4); + out->split(-1, 32); + out->reorder({{1, 2}, {2, 1}}); + TransformPropagator::from(out); + + tv0->computeAt(out, 2); + + out->axis(3)->parallelize(ParallelType::TIDx); + out->axis(2)->parallelize(ParallelType::TIDy); + out->axis(0)->parallelize(ParallelType::BIDx); + + scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion)); + + tv0_cache->doubleBuffer(); + + int numel_x = 99; + int numel_y = 101; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto t1 = gather(t0, {3, 3}, {{1, 1}, {1, 1}}); + auto t2 = sum(t1, {-2, -1}); + auto t3 = t2 / 9; + auto ref = t3; + + testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/test/cpp/jit/test_gpu_validator.h b/test/cpp/jit/test_gpu_validator.h index 5923e384e39..4b01f361cfc 100644 --- a/test/cpp/jit/test_gpu_validator.h +++ b/test/cpp/jit/test_gpu_validator.h @@ -4,6 +4,7 @@ #include #include +#include #include namespace torch { @@ -11,6 +12,25 @@ namespace jit { namespace fuser { namespace cuda { +inline bool deviceMajorMinorCheck(int major, int minor = 0) { + auto dev_prop = at::cuda::getDeviceProperties(0); + if (dev_prop->major < major || + (dev_prop->major == major && dev_prop->minor < minor)) { + return false; + } + return true; +} + +class NVFuserTest : public ::testing::Test { + protected: + void SetUp() override { + // requires PASCAL or newer + if (!deviceMajorMinorCheck(6)) { + GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs"; + } + } +}; + struct ValidationConstants { // Tolerances generated from randn + add + sum fusion // compared against double precision @@ -66,8 +86,8 @@ std::pair getTolerance( } else { // Reduction case size_t entry = 0; - while (sum_tolerance_entry[entry][0] < reduction_size && - entry < sum_tolerance_entry.size()) { + while (entry < sum_tolerance_entry.size() && + sum_tolerance_entry[entry][0] < reduction_size) { entry++; } double abs_tol = 0.0; @@ -221,7 +241,7 @@ class ReductionSizeMapper : private IterVisitor { } void handle(Expr* expr) override { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return; } diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 1d78a19c5ad..299c738c570 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -3,6 +3,10 @@ import unittest import os import random +import enum +import copy +from functools import reduce +import operator import torch from torch.nn import functional @@ -20,6 +24,8 @@ import numpy as np import math +from torch.autograd.gradcheck import gradcheck + from typing import List CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')) @@ -465,21 +471,25 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) def _unary_test_helper(self, operation, dtype, random_data): - shape = (4, 8, 32, 32) + gradient_check = (dtype == torch.float64) and random_data + shape = (8, 7) + torch.cuda.manual_seed_all(211) # need additional def of t for boolean ops def t(x: torch.Tensor, y: torch.Tensor): o = x * y + o = o + 5e-3 o = operation(o) return o - y = torch.tensor([1], device="cuda").to(dtype) + y = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check) + y = y.to(dtype=dtype) if random_data: - x = torch.randn(shape, dtype=torch.float32, device="cuda") + x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check) if dtype in self.int_types: # prefer a larger variance for integer types - x *= 5 + x = x * 5 x = x.to(dtype=dtype) else: x = self.special_values.to(dtype=dtype) @@ -491,14 +501,14 @@ def t(x: torch.Tensor, y: torch.Tensor): t_jit = torch.jit.script(t) jit_o = t_jit(x, y) jit_o = t_jit(x, y) - if dtype in self.support_tensor_dtypes: + jit_o = t_jit(x, y) + if gradient_check: + gradcheck(t_jit, [x, y], nondet_tol=1e-5) + elif dtype in self.support_tensor_dtypes: self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) o = t(x, y) self.assertEqual(o.dtype, jit_o.dtype) - self.assertEqual(o, jit_o, msg=f""" - failing case: - {dtype} {operation} {x} - """) + self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2)) @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -651,6 +661,16 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): o = o + z return o + def t_int(x: torch.Tensor, y: torch.Tensor): + o = operation(x, y) + o = 2 + o + return o + + def t_float(x: torch.Tensor, y: torch.Tensor): + o = operation(x, y) + o = 2. + o + return o + shape = (4, 32, 32) if random_data: x = (torch.randn(shape, dtype=torch.float, device="cuda") * 5).to(dtype_arg1) @@ -665,14 +685,16 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): if operation in div_like and (dtype_arg2 == torch.int32 or dtype_arg2 == torch.int64): y[y == 0] = 1 - o = t(x, y, z) - t_jit = torch.jit.script(t) - jit_o = t_jit(x, y, z) - jit_o = t_jit(x, y, z) + for test_fn in [t, t_int, t_float]: + o = t(x, y, z) + t_jit = torch.jit.script(t) + jit_o = t_jit(x, y, z) + jit_o = t_jit(x, y, z) + jit_o = t_jit(x, y, z) - self.assertEqual(o.dtype, jit_o.dtype) - self.assertEqual(o, jit_o) - self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) + self.assertEqual(o.dtype, jit_o.dtype) + self.assertEqual(o, jit_o) + self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -887,6 +909,21 @@ def test_ternary_ops_type_promotion(self): self._ternary_test_helper(op, dtypes, True) # random data self._ternary_test_helper(op, dtypes, False) # special numbers + # We can't test the scalar version of rsub from python + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") + def test_rsub(self): + x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") + y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") + + def rsub(x: torch.Tensor, y: torch.Tensor): + o = torch.rsub(x, y) + o = o * 2. + return o + + rsub_jit = torch.jit.script(rsub) + self._run_helper(rsub_jit, rsub, x, y) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") # legacy fuser does not work for rand_like, see issue #34361 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -1008,6 +1045,8 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): torch._C._jit_set_nvfuser_guard_mode(old_guard) @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") def test_random_topo(self): os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1" self.assertTrue(runDefaultTestWithSeed(28449)) @@ -1272,7 +1311,6 @@ def forward(self, x: torch.Tensor): self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error)) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) - @unittest.skipIf(True, "codegen failure awaiting fix") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -1287,7 +1325,6 @@ def test_native_layer_norm(self): norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)] self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine) - @unittest.skipIf(True, "codegen failure awaiting fix") @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -2306,7 +2343,7 @@ def t(x: torch.Tensor, y: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 - o = torch.matmul(o, y) + o = torch._C._nn.linear(o, y) return o x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True) @@ -2380,7 +2417,7 @@ def t(x: torch.Tensor, y: torch.Tensor): o = x * 2.0 o = torch.softmax(o, dim=-1) o = o * 3.0 - o = torch.matmul(o, y) + o = torch._C._nn.linear(o, y) return o x = torch.randn(8, 4, dtype=torch.bfloat16, device='cuda', requires_grad=True) @@ -2731,8 +2768,8 @@ def __init__(self, num_features=10, affine=True, track_running_stats=True): track_running_stats=track_running_stats).to(dtype=dtype) def forward(self, x): - o = x * 2.0 - o = self.bn(o) + o = self.bn(x) + o = o * 2.0 return o x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_() @@ -3055,6 +3092,7 @@ def _run_fwd_helper(self, func, ops, *args): for op in ops: self.assertGraphContainsExactly(graph, op, 0) + @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -3068,7 +3106,7 @@ def t(x: torch.Tensor): o1 = x + 1.0 o2 = x * 0.5 return o1, o2 - self._run_fwd_helper(t, ['aten::add'], x) + self._run_fwd_helper(t, ['aten::add', 'aten::mul'], x) def t2(x: torch.Tensor, y: torch.Tensor): o1 = x.sum(0) @@ -3076,7 +3114,6 @@ def t2(x: torch.Tensor, y: torch.Tensor): return o1, o2 self._run_fwd_helper(t2, ['aten::sum', 'aten::mul'], x, y) - @unittest.skipIf(True, "Fixed in PR #68804") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -3120,6 +3157,343 @@ def t(x: torch.Tensor, y: torch.Tensor): graph = jitted.graph_for(x, y) self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) + def _bias_view_relu_helper(self, shape, output_shape, dtype, device, error): + class BiasViewRelu(torch.nn.Module): + def __init__(self): + super(BiasViewRelu, self).__init__() + self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) + with torch.no_grad(): + self.bias.fill_(10) + + def forward(self, inputs : torch.Tensor, view_shape : List[int]): + o = inputs + self.bias + o = o.view(view_shape) + return torch.relu(o) + + t = BiasViewRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + # profiling + jit_o = t_jit(x, output_shape) + # optimization + jit_o = t_jit(x, output_shape) + # final + jit_o = t_jit(x, output_shape) + # eager - baseline + o = t(x, output_shape) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x, output_shape) + + has_inferred_dimension = any([dim == -1 for dim in output_shape]) + if has_inferred_dimension: + # prohibit fusing when view_shape contains an inferred dimension + self.assertGraphContainsExactly(graph, FUSION_GROUP, 0) + self.assertGraphContainsExactly(graph, 'prim::view_copy', 0) + else: + self.assertGraphContains(graph, FUSION_GUARD) + self.assertGraphContains(graph, 'prim::view_copy', True) + + def _alias_bias_view_relu_helper(self, shape, output_shape, dtype, device, error): + class BiasViewRelu(torch.nn.Module): + def __init__(self): + super(BiasViewRelu, self).__init__() + self.bias = torch.nn.Parameter(torch.randn(shape, dtype=dtype, device=device), requires_grad=False) + with torch.no_grad(): + self.bias.fill_(10) + + def forward(self, inputs : torch.Tensor, view_shape : List[int]): + o = inputs.view(view_shape) + inputs = inputs * self.bias + return torch.relu(o) + + t = BiasViewRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + # profiling + jit_o = t_jit(x, output_shape) + # optimization + jit_o = t_jit(x, output_shape) + # final + jit_o = t_jit(x, output_shape) + # eager - baseline + o = t(x, output_shape) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x, output_shape) + self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) + self.assertGraphContainsExactly(graph, 'prim::view_copy', 0) + + # generate random view given original view + def _random_view(self, original_view, max_len=8, max_views=10000): + class Moves(enum.Enum): + Merge = 0 + Split = 1 + Broadcast = 2 + ImplicitBroadcast = 3 + Keep = 4 + + def valid(old_view, new_view): + old_view_size = reduce(operator.mul, old_view) + new_view_size = reduce(operator.mul, new_view) + return old_view_size == new_view_size + + # given a random starting number, find the nearest divisor + def find_nearest_divisor(N): + if 2 >= (N - 1): + return -1 + result = random.randint(2, N - 1) + while (N % result) != 0: + result += 1 + return result + + complete_views = set([tuple(original_view)]) + + to_visit = [] + # empty new view, curent originaal view, start pos=0, move count = 0, last_move + to_visit.append(([], original_view, 0, [], Moves.Keep)) + + # depth-first search of view shapes, starting from the original view + while len(to_visit) > 0 and len(complete_views) < max_views: + new_view, old_view, odx, move_list, last_move = to_visit[-1] + to_visit.pop() + + # iterate over each move type + for idx in range(len(Moves)): + state = Moves(idx) + new_view_clone = copy.deepcopy(new_view) + old_view_clone = copy.deepcopy(old_view) + new_move_list = move_list + [state] + new_odx = odx + + # Update state using Move state + if state == Moves.Keep: + new_size = old_view_clone[odx] + new_view_clone.append(new_size) + new_odx += 1 + + elif state == Moves.Merge: + if odx + 1 < len(old_view_clone): + new_size = old_view_clone[odx] * old_view_clone[odx + 1] + new_view_clone.append(new_size) + new_odx += 2 + else: + continue + + elif state == Moves.Broadcast and last_move != Moves.Broadcast: + new_view_clone.append(1) + + elif state == Moves.Split: + new_size = find_nearest_divisor(old_view_clone[odx]) + if new_size == -1: + continue + new_view_clone.append(new_size) + old_view_clone[odx] = int(old_view[odx] / new_size) + + if old_view_clone[odx] == 1: + new_odx += 1 + + elif state == Moves.ImplicitBroadcast: + old_view_clone.insert(odx + 1, 1) + new_size = old_view[odx] * 1 + new_view_clone.append(new_size) + new_odx += 2 + + if new_odx < len(old_view_clone) and len(new_move_list) < max_len: + to_visit.append((new_view_clone, old_view_clone, new_odx, new_move_list, state)) + elif (valid(original_view, new_view_clone)): + final_new_view = tuple(new_view_clone) + complete_views.add(final_new_view) + return list(complete_views) + + # ndims - number of dimensions + # test_fn - view test function + def _view_test_generator(self, ndims, test_fn): + # create random tensor + # max value for each dimension + max_size = 10e7 + max_value = max(int(pow(max_size, 1. / ndims)), 1) + sizes = [random.randint(1, max_value) for idx in range(ndims)] + x = torch.randn(sizes) + + original_sizes = list(x.size()) + all_views = self._random_view(original_sizes) + random.shuffle(all_views) + + max_samples = 20 + max_views = min(len(all_views), max_samples) + total = 0 + correct = 0 + # test random combinations of compatible views + for idx in range(max_views): + for jdx in range(idx + 1, max_views): + total += 1 + test_fn(all_views[idx], all_views[jdx], torch.float, 'cuda', 1e-6) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_view(self): + torch._C._jit_set_nvfuser_guard_mode(True) + self._bias_view_relu_helper([2, 3, 4, 5], [-1, 4, 5], torch.float, 'cuda', 1e-6) + for ndims in range(1, 5): + self._view_test_generator(ndims, self._bias_view_relu_helper) + self._alias_bias_view_relu_helper([2, 3, 4, 5], [1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) + + def _bias_squeeze_relu_helper(self, shape, dtype, device, error): + class BiasSqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasSqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = inputs + bias + o = torch.squeeze(o) + return torch.relu(o) + + t = BiasSqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x) + self.assertGraphContains(graph, FUSION_GUARD) + self.assertGraphContains(graph, 'prim::squeeze_copy', True) + + def _alias_bias_squeeze_relu_helper(self, shape, dtype, device, error): + class BiasSqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasSqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = torch.squeeze(inputs) + inputs = inputs * bias + return torch.relu(o) + + t = BiasSqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x, bias) + self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) + self.assertGraphContainsExactly(graph, 'prim::squeeze_copy', 0) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_squeeze(self): + self._bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) + self._alias_bias_squeeze_relu_helper([1, 6, 1, 2, 2, 5, 1], torch.float, 'cuda', 1e-6) + + def _bias_unsqueeze_relu_helper(self, shape, dtype, device, error): + class BiasUnsqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasUnsqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = inputs + bias + o = torch.unsqueeze(o, 0) + return torch.relu(o) + + t = BiasUnsqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x) + self.assertGraphContains(graph, FUSION_GUARD) + self.assertGraphContains(graph, 'prim::unsqueeze_copy', True) + + def _alias_bias_unsqueeze_relu_helper(self, shape, dtype, device, error): + class BiasUnsqueezeRelu(torch.nn.Module): + def __init__(self): + super(BiasUnsqueezeRelu, self).__init__() + + def forward(self, inputs : torch.Tensor, bias : torch.Tensor): + o = torch.squeeze(inputs) + o = torch.unsqueeze(inputs, 0) + inputs = inputs * bias + return torch.relu(o) + + t = BiasUnsqueezeRelu() + x = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + bias = torch.randn(shape, dtype=dtype, device=device, requires_grad=False) + t_jit = torch.jit.script(t) + + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + jit_o = t_jit(x, bias) + o = t(x, bias) + + self.assertEqual(o.dtype, jit_o.dtype) + self.assertTrue(self._compare("comparing output failed", o, jit_o, error)) + graph = t_jit.graph_for(x) + self.assertGraphContainsExactly(graph, FUSION_GUARD, 0) + self.assertGraphContainsExactly(graph, 'prim::unsqueeze_copy', 0) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_unsqueeze(self): + self._bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6) + self._alias_bias_unsqueeze_relu_helper([2, 3, 4, 5], torch.float, 'cuda', 1e-6) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_alias_pass_fix(self): + x = torch.randn(4, 24, 2, 2, dtype=torch.float, device="cuda") + w = torch.randn(24, 24, 1, 1, dtype=torch.float, device="cuda") + b = torch.randn(24, dtype=torch.float, device="cuda") + + def t(x, w, b): + b2 = b + 1.0 + o = torch.conv2d(x, w, b2) + return o + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, w, b) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_squeeze_negative_dim(self): + x = torch.randn(4, 24, 1, 2, dtype=torch.float, device="cuda") + + def t(x): + o = x + 1.0 + o = o.squeeze(-2) + o = o * 2.0 + return o + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -3154,6 +3528,138 @@ def t(x, y, s): # sibling fusion should be disabled with the flag self.assertGraphContainsExactly(t_jit.graph_for(x, y, s), FUSION_GUARD, 0) + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_build_shape_expression_native_dropout(self): + x = torch.randn(4, 2, device="cuda") + + def t(x): + o, mask = torch.native_dropout(x, 0.0, True) + o1 = o.sigmoid() + o2 = mask.float().sigmoid() + return (o1, o2) + + t_jit = torch.jit.script(t) + + jit_o = t_jit(x) + jit_o = t_jit(x) + o = t(x) + for oo, jit_oo in zip(o, jit_o): + self.assertEqual(oo.dtype, jit_oo.dtype) + self.assertEqual(oo, jit_oo) + self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_scalar_tensor_permuted(self): + x = torch.randn(4, 2, 3, device="cuda").permute([1, 2, 0]) + y = torch.tensor(1.0, device="cuda") + + with nvfuser_singleton_fusion(True): + def t(x, y): + return x + y + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_cpu_scalar(self): + x = torch.randn(4, 2, 3, device="cuda") + y = torch.tensor(1.0, device="cpu") + z = torch.tensor(2.0, device="cpu") + + with nvfuser_singleton_fusion(True): + # testing cpu scalar tensor promotion + def t(x, y): + return x + y + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y) + + # scalar cpu tensor add should NOT be fused + @torch.jit.script + def t1(y, z): + return y * z + for _ in range(5): + t1(y, z) + self.assertGraphContainsExactly(t1.graph_for(y, z), FUSION_GUARD, 0) + + # everything, including scalar cpu tensor add should be fused + @torch.jit.script + def t2(x, y, z): + tmp = y + z + return tmp + x + for _ in range(5): + t2(x, y, z) + self.assertGraphContainsExactly(t2.graph_for(x, y, z), 'aten::add', 0) + self.assertGraphContainsExactly(t2.graph_for(x, y, z), FUSION_GUARD, 1) + + # 'cpu_tmp = y + z' shouldn't be fused. + @torch.jit.script + def t3(x, y, z): + cpu_tmp = y + z + out = x + y + return cpu_tmp, out + for _ in range(5): + t3(x, y, z) + self.assertGraphContainsExactly(t3.graph_for(x, y, z), FUSION_GUARD, 1) + self.assertGraphContainsExactly(t3.graph_for(x, y, z), 'aten::add', 1) + + @unittest.skipIf(not RUN_CUDA, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_shape_expression(self): + x = torch.randn(4, 2, 1, 3, device="cuda") + + def t_unsqueeze(x): + t0 = x.relu() + t1 = t0.unsqueeze(1) + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + def t_squeeze(x): + t0 = x.relu() + t1 = t0.squeeze() + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + def t_squeeze_dim(x): + t0 = x.relu() + t1 = t0.squeeze(-2) + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + # squeezing a non-size 1 dimension should be a no op + def t_squeeze_dim_no_op(x): + t0 = x.relu() + t1 = t0.squeeze(1) + t2 = t1 + 1.0 + t3 = t1.size() + return t2, t3 + + def run(fn): + jit_fn = torch.jit.script(fn) + jit_o = jit_fn(x) + jit_o = jit_fn(x) + jit_o = jit_fn(x) + o = fn(x) + # output 0 is a tensor, so we check dtype and value + self.assertEqual(o[0].dtype, jit_o[0].dtype) + self.assertEqual(o[0], jit_o[0]) + # output 1 is shape + self.assertEqual(o[1], jit_o[1]) + self.assertGraphContainsExactly(jit_fn.graph_for(x), FUSION_GUARD, 1) + + for t in [t_unsqueeze, t_squeeze, t_squeeze_dim, t_squeeze_dim_no_op]: + run(t) + class TestPassManagerCudaFuser(JitTestCase): @unittest.skipIf(not RUN_CUDA, "requires CUDA") diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index da78a7ceb4b..f63e4ea1668 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -635,7 +635,9 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp", "torch/csrc/jit/codegen/cuda/instrumentation.cpp", "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp", + "torch/csrc/jit/codegen/cuda/ir_builder.cpp", "torch/csrc/jit/codegen/cuda/ir_cloner.cpp", + "torch/csrc/jit/codegen/cuda/ir_container.cpp", "torch/csrc/jit/codegen/cuda/ir_graphviz.cpp", "torch/csrc/jit/codegen/cuda/ir_nodes.cpp", "torch/csrc/jit/codegen/cuda/ir_iostream.cpp", @@ -645,28 +647,32 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/kernel_cache.cpp", "torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp", "torch/csrc/jit/codegen/cuda/kernel_ir.cpp", - "torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp", - "torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp", + "torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp", "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp", - "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp", "torch/csrc/jit/codegen/cuda/lower_allocation.cpp", + "torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp", "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp", + "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp", "torch/csrc/jit/codegen/cuda/lower_index.cpp", "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp", "torch/csrc/jit/codegen/cuda/lower_loops.cpp", "torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp", "torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp", "torch/csrc/jit/codegen/cuda/lower_predicate.cpp", + "torch/csrc/jit/codegen/cuda/lower_replace_size.cpp", "torch/csrc/jit/codegen/cuda/lower_shift.cpp", "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp", + "torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp", "torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp", "torch/csrc/jit/codegen/cuda/lower_unroll.cpp", "torch/csrc/jit/codegen/cuda/lower_utils.cpp", "torch/csrc/jit/codegen/cuda/lower_validation.cpp", + "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp", "torch/csrc/jit/codegen/cuda/lower2device.cpp", "torch/csrc/jit/codegen/cuda/manager.cpp", "torch/csrc/jit/codegen/cuda/mutator.cpp", "torch/csrc/jit/codegen/cuda/non_divisible_split.cpp", + "torch/csrc/jit/codegen/cuda/ops/alias.cpp", "torch/csrc/jit/codegen/cuda/ops/composite.cpp", "torch/csrc/jit/codegen/cuda/ops/normalization.cpp", "torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp", diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 2c9925cf893..d9bf46b51c7 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -23,14 +24,15 @@ Val* newScalar(ValType vtype, DataType dtype) { case (ValType::Scalar): switch (dtype) { case DataType::Bool: - return new Bool(); + return IrBuilder::create(); case DataType::Double: case DataType::Float: case DataType::Half: case DataType::BFloat16: - return new Double(); + return IrBuilder::create(); + case DataType::Int32: case DataType::Int: - return new Int(); + return IrBuilder::create(); default: break; } @@ -103,10 +105,10 @@ TensorView* newOutputTV(const std::vector& vals, DataType dtype) { } for (const auto dim_i : c10::irange(out_domain.size())) { if (extent_vals[dim_i] != nullptr) { - out_domain[dim_i] = new IterDomain( - new Int(start_offsets[dim_i]), + out_domain[dim_i] = IrBuilder::create( + IrBuilder::create(start_offsets[dim_i]), extent_vals[dim_i], - new Int(stop_offsets[dim_i]), + IrBuilder::create(stop_offsets[dim_i]), ParallelType::Serial, iter_types[dim_i]); } else { @@ -121,13 +123,17 @@ TensorView* newOutputTV(const std::vector& vals, DataType dtype) { break; } } - out_domain[dim_i] = - new IterDomain(new Int(0), new Int(1), ParallelType::Serial, itype); + out_domain[dim_i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), + ParallelType::Serial, + itype); } } - return new TensorView( - new TensorDomain(out_domain, std::vector(out_domain.size(), true)), + return IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), dtype); } @@ -195,7 +201,7 @@ Val* castOp(DataType dtype, Val* v1) { } Val* out = newValLike(v1, dtype); - new UnaryOp(UnaryOpType::Cast, out, v1); + IrBuilder::create(UnaryOpType::Cast, out, v1); return out; } @@ -219,7 +225,7 @@ Val* unaryOp(UnaryOpType type, Val* v1) { // } Val* out = newValLike(v1, v1->getDataType().value()); - new UnaryOp(type, out, v1); + IrBuilder::create(type, out, v1); return out; } @@ -379,7 +385,7 @@ Val* binaryOp(BinaryOpType type, Val* v1, Val* v2, DataType common_dtype) { } else { out = newScalar(out_vtype, out_dtype); } - new BinaryOp(type, out, vals[0], vals[1]); + IrBuilder::create(type, out, vals[0], vals[1]); return out; } @@ -589,7 +595,7 @@ static TensorView* newForReduction( " of tensor ", tv); - new_domain.push_back(new IterDomain( + new_domain.push_back(IrBuilder::create( id->start(), id->extent(), id->stopOffset(), @@ -597,12 +603,12 @@ static TensorView* newForReduction( isReduction ? IterType::Reduction : id->getIterType())); } - TensorDomain* td = - new TensorDomain(new_domain, std::vector(new_domain.size(), true)); + TensorDomain* td = IrBuilder::create( + new_domain, std::vector(new_domain.size(), true)); data_type = data_type == DataType::Null ? tv->getDataType().value() : data_type; - return new TensorView(td, data_type); + return IrBuilder::create(td, data_type); } TensorView* reductionOp( @@ -652,7 +658,7 @@ TensorView* reductionOp( out_type, " and ", init_type); - new ReductionOp(reduction_op_type, init, out, tv); + IrBuilder::create(reduction_op_type, init, out, tv); if (keep_dim) { auto tv_root = TensorDomain::noReductions(tv->getRootDomain()); @@ -673,9 +679,9 @@ TensorView* sum( Val* init = nullptr; auto dtype = v1->getDataType().value(); if (isFloatingPointType(dtype)) { - init = new Double(0.0); + init = IrBuilder::create(0.0); } else if (isIntegralType(dtype)) { - init = new Int(0); + init = FusionGuard::getCurFusion()->zeroVal(); } else { TORCH_CHECK( false, @@ -693,13 +699,13 @@ TensorView* max( Val* init = nullptr; switch (v1->getDataType().value()) { case (DataType::Double): - init = new Double(std::numeric_limits::lowest()); + init = IrBuilder::create(std::numeric_limits::lowest()); break; case (DataType::Float): - init = new Double(std::numeric_limits::lowest()); + init = IrBuilder::create(std::numeric_limits::lowest()); break; case (DataType::Int): - init = new Int(INT_MIN); + init = IrBuilder::create(INT_MIN); break; default: TORCH_CHECK( @@ -718,13 +724,13 @@ TensorView* min( Val* init = nullptr; switch (v1->getDataType().value()) { case (DataType::Double): - init = new Double(DBL_MAX); + init = IrBuilder::create(DBL_MAX); break; case (DataType::Float): - init = new Double(FLT_MAX); + init = IrBuilder::create(FLT_MAX); break; case (DataType::Int): - init = new Int(INT_MAX); + init = IrBuilder::create(INT_MAX); break; default: TORCH_CHECK( @@ -767,9 +773,9 @@ TensorView* broadcast( size_t iinp = 0, ibdim = 0; while (ibdim < is_broadcast_dim.size()) { if (is_broadcast_dim[ibdim]) { - out_domain.push_back(new IterDomain( - new Int(0), - new Int(1), + out_domain.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), ParallelType::Serial, IterType::BroadcastWithoutStride)); } else { @@ -779,10 +785,11 @@ TensorView* broadcast( ibdim++; } - TensorView* out_tensor = new TensorView( - new TensorDomain(out_domain, std::vector(out_domain.size(), true)), + TensorView* out_tensor = IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), inp->getDataType().value()); - new BroadcastOp(out_tensor, inp, is_broadcast_dim); + IrBuilder::create(out_tensor, inp, is_broadcast_dim); return out_tensor; } @@ -799,6 +806,10 @@ WelfordResult Welford( TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor"); TORCH_CHECK(axes.size() > 0, "No reduction axis specified"); + if (init_N == nullptr) { + init_N = FusionGuard::getCurFusion()->zeroVal(); + } + // Initial values for welford op are tensors, so their dims have to match the // output dim, // i.e. original_dims - dims_to_be_reduced @@ -819,8 +830,8 @@ WelfordResult Welford( init_avg_val = init_avg; init_var_val = init_var; } else { - init_avg_val = new Double(0); - init_var_val = new Double(0); + init_avg_val = IrBuilder::create(0); + init_var_val = IrBuilder::create(0); } // Check and collect reduction axes @@ -847,7 +858,7 @@ WelfordResult Welford( TensorView* out_var = newForReduction(tv, uint_axes); TensorView* out_N = newForReduction(tv, uint_axes, DataType::Int); - new WelfordOp( + IrBuilder::create( out_avg, out_var, out_N, /*out var/avg/count */ @@ -856,7 +867,7 @@ WelfordResult Welford( init_N, /*init var/avg/count */ tv, nullptr, - new Int(1)); /*in var/avg/count */ + FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */ return WelfordResult(out_avg, out_var, out_N); } @@ -888,10 +899,11 @@ TensorView* transpose( out_domain[i] = in_id->clone(); } - TensorView* out_tensor = new TensorView( - new TensorDomain(out_domain, std::vector(out_domain.size(), true)), + TensorView* out_tensor = IrBuilder::create( + IrBuilder::create( + out_domain, std::vector(out_domain.size(), true)), inp->getDataType().value()); - new TransposeOp(out_tensor, inp, new2old); + IrBuilder::create(out_tensor, inp, new2old); return out_tensor; } @@ -938,7 +950,7 @@ TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* v3) { return arithOpOverloads(sub_alpha, v1, v2, v3); } // lerp -TORCH_CUDA_CU_API Val* lerp(Val* start, Val* end, Val* weight) { +Val* lerp(Val* start, Val* end, Val* weight) { auto vals = maybeBroadcast({start, end, weight}); Val* intrm1 = sub(vals[1], vals[0]); Val* intrm2 = mul(vals[2], intrm1); @@ -1024,7 +1036,8 @@ Val* where(Val* c, Val* v1, Val* v2) { } else { out = newScalar(out_vtype, out_dtype); } - new TernaryOp(TernaryOpType::Where, out, vals[0], vals[1], vals[2]); + IrBuilder::create( + TernaryOpType::Where, out, vals[0], vals[1], vals[2]); return out; } @@ -1064,7 +1077,8 @@ Val* threshold(Val* in, Val* thresh, Val* value) { value = optionalCast(in->getDataType().value(), value); Val* out = newValLike(in, in->getDataType().value()); - new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value); + IrBuilder::create( + TernaryOpType::Threshold, out, in, thresh, value); return out; } @@ -1084,7 +1098,7 @@ Val* clamp(Val* in, Val* min_val, Val* max_val) { max_val = optionalCast(in->getDataType().value(), max_val); Val* out = newValLike(in, in->getDataType().value()); - new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val); + IrBuilder::create(TernaryOpType::Clamp, out, in, min_val, max_val); return out; } @@ -1186,125 +1200,157 @@ TensorView* sum_to(TensorView* in, const std::vector& sum_to_size) { } TensorView* shift(TensorView* inp, const std::vector& offsets, bool pad) { + // When pad is false, no padding is given. When it is true, padding + // sizes are set so that output domains have the same extents as + // input domains. + std::vector pad_width(offsets.size(), 0); + if (pad) { + for (const auto i : c10::irange(offsets.size())) { + pad_width[i] = std::abs(offsets[i]); + } + } + return shift(inp, offsets, pad_width); +} + +TensorView* shift( + TensorView* inp, + const std::vector& offsets, + const std::vector& pad_width_param) { + auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); + const auto ndims = inp_dom.size(); + + auto pad_width = pad_width_param; + // Default padding is set so that the extent is kept unchanged + if (pad_width.empty()) { + pad_width = offsets; + for (auto& p : pad_width) { + p = std::abs(p); + } + } + TORCH_CHECK( - TensorDomain::noReductions(inp->getRootDomain()).size() == offsets.size(), + ndims == offsets.size(), "Invalid shift offsets, number of entries in offsets expected to be ", - TensorDomain::noReductions(inp->getRootDomain()).size(), + ndims, " but received ", offsets.size()); + TORCH_CHECK( + ndims == pad_width.size(), + "Invalid padding width list, number of entries in pad_width expected to be ", + ndims, + " but received ", + pad_width.size()); + + std::for_each(pad_width.begin(), pad_width.end(), [](const auto& pad) { + TORCH_CHECK(pad >= 0, "Padding width must be >= 0: ", pad); + }); + TensorView* out = nullptr; - if (pad) { - out = newValLike(inp, inp->getDataType().value())->as(); - } else { - auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); - const auto ndims = inp_dom.size(); - std::vector out_dom; - for (const auto i : c10::irange(ndims)) { - const auto inp_axis = inp_dom[i]; - const auto offset = offsets[i]; - if (offset == 0) { - out_dom.push_back(inp_axis->clone()); - continue; - } + std::vector out_dom; + for (const auto i : c10::irange(ndims)) { + const auto inp_axis = inp_dom[i]; + const auto offset = offsets[i]; + const auto pad = pad_width[i]; - Int* current_start_offset = dynamic_cast(inp_axis->start()); - TORCH_INTERNAL_ASSERT( - current_start_offset != nullptr && current_start_offset->isConst(), - "Invalid IterDomain start value:", - current_start_offset); + if (offset == 0) { + out_dom.push_back(inp_axis->clone()); + continue; + } - Int* current_stop_offset = dynamic_cast(inp_axis->stopOffset()); - TORCH_INTERNAL_ASSERT( - current_stop_offset != nullptr && current_stop_offset->isConst(), - "Invalid IterDomain stop offset value:", - current_stop_offset); - - const auto cur_start_offset_value = current_start_offset->value().value(); - const auto cur_stop_offset_value = current_stop_offset->value().value(); - - Val* out_start_offset = nullptr; - Val* out_stop_offset = nullptr; - - if (offset > 0) { - // shift to right; extent remains the same, start and stop - // positions are moved right - out_start_offset = new Int(cur_start_offset_value + offset); - out_stop_offset = - new Int(std::max(cur_stop_offset_value - offset, int64_t(0))); - } else { - // shift to left; extent remains the same, start and stop - // positions are moved left - out_start_offset = - new Int(std::max(cur_start_offset_value + offset, int64_t(0))); - out_stop_offset = new Int(cur_stop_offset_value - offset); - } + Int* current_start_offset = dynamic_cast(inp_axis->start()); + TORCH_INTERNAL_ASSERT( + current_start_offset != nullptr && current_start_offset->isConst(), + "Invalid IterDomain start value:", + current_start_offset); - out_dom.push_back(new IterDomain( - out_start_offset, - inp_axis->extent(), - out_stop_offset, - ParallelType::Serial, - inp_axis->getIterType())); + Int* current_stop_offset = dynamic_cast(inp_axis->stopOffset()); + TORCH_INTERNAL_ASSERT( + current_stop_offset != nullptr && current_stop_offset->isConst(), + "Invalid IterDomain stop offset value:", + current_stop_offset); + + const auto cur_start_offset_value = current_start_offset->value().value(); + const auto cur_stop_offset_value = current_stop_offset->value().value(); + + int64_t out_start_offset = 0; + int64_t out_stop_offset = 0; + + if (offset > 0) { + // shift to right; extent remains the same, start and stop + // positions are moved right + out_start_offset = cur_start_offset_value + offset - pad; + out_stop_offset = std::max(cur_stop_offset_value - offset, int64_t(0)); + // If pad > offset, the extent of the output ID could be larger than the + // input, and the start offset of the output domain could become + // negative, which is not supported. + TORCH_CHECK( + out_start_offset >= 0, + "Invalid shift offset and padding. Padding must not be larger than the absolute extent of shift offset. Padding: ", + pad, + ". Shift: ", + offset, + "."); + } else { + // shift to left; extent remains the same, start and stop + // positions are moved left + out_start_offset = std::max(cur_start_offset_value + offset, int64_t(0)); + out_stop_offset = cur_stop_offset_value - offset - pad; + // Similar to the above case whwere offset is positive, if pad > + // -offset (note offset is negative), the extent of the output + // ID could be larger than the input, and the stop offset of the + // output domain could become negative. + TORCH_CHECK( + out_stop_offset >= 0, + "Invalid shift offset and padding. Padding must not be larger than the absolute extent of shift offset. Padding: ", + pad, + ". Shift: ", + offset, + "."); } - out = new TensorView( - new TensorDomain(out_dom, std::vector(out_dom.size(), true)), - inp->getDataType().value()); + out_dom.push_back(IrBuilder::create( + IrBuilder::create(out_start_offset), + inp_axis->extent(), + IrBuilder::create(out_stop_offset), + ParallelType::Serial, + inp_axis->getIterType())); } - new ShiftOp(out, inp, offsets, pad); - return out; -} - -namespace { -std::vector convertToIntVector(const std::vector& x) { - std::vector converted; - std::transform(x.begin(), x.end(), std::back_inserter(converted), [](int x) { - return new Int(x); - }); - return converted; -} -} // namespace + out = IrBuilder::create( + IrBuilder::create( + out_dom, std::vector(out_dom.size(), true)), + inp->getDataType().value()); -TensorView* gather( - TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides) { - std::vector window_shape_int = convertToIntVector(window_shape); - std::vector> pad_width_int; - std::transform( - pad_width.begin(), - pad_width.end(), - std::back_inserter(pad_width_int), - [](const std::vector& x) { return convertToIntVector(x); }); - return gather(inp, window_shape_int, pad_width_int, strides); + IrBuilder::create(out, inp, offsets, pad_width); + return out; } namespace { -// Return a new TensorDomain with given root domains. Apply strides if -// necessary. With non-unit strides, strided domains become an rfactor -// domain. +// Return a new TensorDomain with given root domains. Apply +// strides if necessary. With non-unit strides, strided domains become an +// rfactor domain. TensorDomain* generateTensorDomainWithStrides( const std::vector& root_domains, - const std::vector& strides) { + const std::vector& strides, + bool skip_unit_stride) { std::vector strided_domains; // If strides are just unit strides, don't apply striding - if (strides.empty() || std::all_of(strides.begin(), strides.end(), [](int s) { - return s == 1; - })) { - return new TensorDomain( + if (strides.empty() || + (skip_unit_stride && + std::all_of( + strides.begin(), strides.end(), [](int s) { return s == 1; }))) { + return IrBuilder::create( root_domains, std::vector(root_domains.size(), true)); } for (const auto i : c10::irange(root_domains.size())) { auto root_dom = root_domains.at(i); - if (i >= strides.size() || strides[i] == 1) { + if (i >= strides.size() || (skip_unit_stride && strides[i] == 1)) { strided_domains.push_back(root_dom); continue; } @@ -1317,7 +1363,7 @@ TensorDomain* generateTensorDomainWithStrides( auto contig_vector_size = strided_domains.size(); - auto strided_td = new TensorDomain( + auto strided_td = IrBuilder::create( root_domains, strided_domains, strided_domains, @@ -1330,9 +1376,10 @@ TensorDomain* generateTensorDomainWithStrides( TensorView* gather( TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides) { + const std::vector& window_shape, + const std::vector>& pad_width, + const std::vector& strides, + bool trim_out_of_bounds) { auto inp_dom = TensorDomain::noReductions(inp->getRootDomain()); const auto ndims = inp_dom.size(); @@ -1343,6 +1390,10 @@ TensorView* gather( " but received ", window_shape.size()); + std::for_each(window_shape.begin(), window_shape.end(), [](const auto& w) { + TORCH_CHECK(w > 0, "Window size must be > 0: ", w); + }); + TORCH_CHECK( ndims == pad_width.size(), "Invalid pad width: number of entries expected to be ", @@ -1354,6 +1405,10 @@ TensorView* gather( TORCH_CHECK( p.size() == 2, "Each entry of pad_width must have two non-negative integers."); + std::for_each(p.begin(), p.end(), [](const auto& p_left_or_right) { + TORCH_CHECK( + p_left_or_right >= 0, "Padding must be >= 0: ", p_left_or_right); + }); }); TORCH_CHECK( @@ -1363,6 +1418,10 @@ TensorView* gather( " but received ", strides.size()); + std::for_each(strides.begin(), strides.end(), [](const auto& s) { + TORCH_CHECK(s > 0, "Stride must be > 0: ", s); + }); + std::vector out_root_domains; std::vector out_gather_dom; @@ -1371,40 +1430,57 @@ TensorView* gather( const auto window_dim = window_shape[i]; const auto pad_left = pad_width[i][0]; const auto pad_right = pad_width[i][1]; + // This may be over-conservative TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt()); + const auto inp_stop_offset = inp_axis->stopOffset()->getInt(); + TORCH_INTERNAL_ASSERT( + inp_stop_offset.has_value(), + "Dynamic stop offset not supported: ", + inp_axis); + const auto extent_adjustment = window_dim - 1 - pad_left - pad_right; + TORCH_CHECK( + extent_adjustment >= 0, + "Invalid gather window and padding as output extent would be larger than input.", + " Window: ", + window_dim, + ". Padding left: ", + pad_left, + ". Padding right: ", + pad_right); + const auto out_stop_offset = inp_stop_offset.value() + extent_adjustment; Val* out_axis_dim = nullptr; - if (window_dim->isConst() && pad_left->isConst() && pad_right->isConst()) { - const int64_t extent_adjustment = - -(-window_dim->value().value() + 1 + pad_left->value().value() + - pad_right->value().value()); - out_axis_dim = extent_adjustment == 0 - ? inp_axis->extent() - : sub(inp_axis->extent(), new Int(extent_adjustment)); - } else { - out_axis_dim = - add(add(sub(inp_axis->extent(), window_dim), new Int(1)), - add(pad_left, pad_right)); - } - // TODO: out_axis_dim is assumed to be the same as the extent of - // the input domain. Throw an error if it isn't the case. - out_root_domains.push_back(new IterDomain( - new Int(0), - out_axis_dim, + out_root_domains.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + inp_axis->extent(), + IrBuilder::create(out_stop_offset), ParallelType::Serial, inp_axis->getIterType())); // create a new axis for the gathered domain - out_gather_dom.push_back(new IterDomain( - new Int(0), window_dim, ParallelType::Serial, IterType::Gather)); + out_gather_dom.push_back(IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create(window_dim), + ParallelType::Serial, + IterType::Gather)); } out_root_domains.insert( out_root_domains.end(), out_gather_dom.begin(), out_gather_dom.end()); - auto out_td = generateTensorDomainWithStrides(out_root_domains, strides); + TensorDomain* out_td = nullptr; + + if (trim_out_of_bounds) { + // If no stride vector is given, just use stride 1. It does not do + // any striding effect, but out-of-bounds values are trimmed. + auto s = strides.empty() ? std::vector(ndims, 1) : strides; + out_td = generateTensorDomainWithStrides(out_root_domains, strides, false); + } else { + out_td = generateTensorDomainWithStrides(out_root_domains, strides, true); + } - auto out_tv = new TensorView(out_td, inp->getDataType().value()); + auto out_tv = + IrBuilder::create(out_td, inp->getDataType().value()); - new GatherOp(out_tv, inp, window_shape, pad_width); + IrBuilder::create(out_tv, inp, window_shape, pad_width); return out_tv; } diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index 5652d68eab8..1f18f65666a 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -114,7 +114,9 @@ TORCH_CUDA_CU_API WelfordResult Welford( const std::vector& axes, TensorView* init_avg = nullptr, TensorView* init_var = nullptr, - Int* init_N = new Int(0)); + // Initializes to 0 in function definition, doing this so we don't have to + // import IrBuilder just for this one interface. + Int* init_N = nullptr); // UNARY OPERATIONS // abs @@ -484,19 +486,27 @@ TORCH_CUDA_CU_API TensorView* sum_to( //! t1[i, j] = 0, otherwise //! //! The pad option controls how out-of-boundary accesses are -//! handled. When pad is true, shifting works as if the source tensor -//! is padded by zero. Otherwise, it does not modify the output tensor -//! region whose source coordinates are out-of-boundry. In both cases, -//! the size of output tensor does not change. However, when pad is -//! false, the start or stop value of the shifted axis is adjusted -//! accordingly. For example, when a shift offset is one, the axis start -//! value would be incremented by one. +//! handled. It specifies how many zeros are logically padded. If no +//! pad option is given, it automatically pads the input tensor so +//! that the output tensor has the same extent for each axis. //! -//! \param pad If true, out-of-boundary access returns zero. +//! When a padding value is smaller than the absolute value of a shift +//! offset, the output axis still has the same extent but its start or +//! stop offset is moved inward to signify those outside of the offset +//! are invalid. +//! +//! It is not allowed to use padding values that are larger than shift +//! offsets, which would mean output extentes would be larger than +//! input extents TORCH_CUDA_CU_API TensorView* shift( TensorView* inp, const std::vector& offsets, - bool pad = true); + const std::vector& pad_width = {}); + +TORCH_CUDA_CU_API TensorView* shift( + TensorView* inp, + const std::vector& offsets, + bool pad); //! Gather a window of nearby elements for each element. //! @@ -508,8 +518,13 @@ TORCH_CUDA_CU_API TensorView* shift( //! implemented with strided split, whose outer output domain becomes //! the root domain for subsequent consumers. The inner output domain //! becomes a Stride domain, which is ignored by subsequent consumers. +//! Only valid input ranges are fed into strided splits. //! -//! Example: +//! When trim_out_of_bounds is true, the values at the first and last +//! ends that are outside of the start and stop offsets are +//! effetively trimmed by partial split by 1. +//! +//! Example 1: //! t0: 2D tensor of [N, M] //! t1 = gather(t0, {1, 3}, {{0, 0}, {1, 1}}); //! @@ -517,23 +532,34 @@ TORCH_CUDA_CU_API TensorView* shift( //! t1: [N, M, 1, 3] //! t1[i, j, k, l] = The value at the window position of [k, l] //! for t0[i, j] -TORCH_CUDA_CU_API TensorView* gather( - TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides = {}); - -//! Gather a window of nearby elements for each element. //! -//! Same as the another gather interface but with Int* parameters. +//! Example 2.1 (without trimming): +//! t0: 2D tensor of [N, M] +//! t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}}); +//! +//! then: +//! t1: [N (stop offset: 1), M (stop offset: 1, 2, 2)] +//! +//! Example 2.1 (with trimming) +//! t0: 2D tensor of [N, M] +//! t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}}, true); +//! +//! then: +//! t1: [ceilDiv(N - 1, 1), ceilDiv(M - 1, 1), 2, 2] +//! +//! Example 3: +//! t0: 2D tensor of [N, M] +//! t1 = gather(t0, {3, 3}, {{0, 0}, {0, 0}}, {3, 3}); +//! +//! then: +//! t1: [ceilDiv(N - 2, 3), ceilDiv(M - 2, 3), 2, 2] //! -//! TODO: Remove this interface as we do not intend to support dynamic -//! window shapes at this moment. TORCH_CUDA_CU_API TensorView* gather( TensorView* inp, - const std::vector& window_shape, - const std::vector>& pad_width, - const std::vector& strides = {}); + const std::vector& window_shape, + const std::vector>& pad_width, + const std::vector& strides = {}, + bool trim_out_of_bounds = false); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 709c810efe3..67926e92672 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -19,7 +20,7 @@ namespace codegen { namespace { -class CudaKernelGenerator : private kir::IrVisitor { +class CudaKernelGenerator : private OptOutConstDispatch { static constexpr const char* kTab = " "; public: @@ -45,7 +46,7 @@ class CudaKernelGenerator : private kir::IrVisitor { code_ << "__global__ void " << kernel_name << "("; - std::vector params; + std::vector params; // Inputs & Outputs for (auto val : kernel_->inputs()) { @@ -56,13 +57,16 @@ class CudaKernelGenerator : private kir::IrVisitor { } // Generate parameter declarations - for (kir::Val* val : params) { - if (const auto tv = dynamic_cast(val)) { - code_ << "Tensor<" << val->dtype() << ", " - << TensorDomain::noReductions( - tv->fuserTv()->getMaybeRFactorDomain()) - .size() + for (Val* val : params) { + if (const auto tv = dynamic_cast(val)) { + if (tv->isCpuScalar()) { + code_ << " CpuScalarTensor<" << val->dtype() << "> " << varName(tv); + } else { + code_ + << "Tensor<" << val->dtype() << ", " + << TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size() << "> " << varName(tv); + } } else { TORCH_INTERNAL_ASSERT(val->isScalar()); // NOLINT (LLVM bug 48525) TORCH_INTERNAL_ASSERT(val->definition() == nullptr); @@ -76,17 +80,17 @@ class CudaKernelGenerator : private kir::IrVisitor { // Global buffers for (auto allocate : kernel_summary.global_allocations) { - TORCH_INTERNAL_ASSERT(allocate->buffer()->isA()); - const auto tv = allocate->buffer()->as(); + TORCH_INTERNAL_ASSERT(allocate->buffer()->isA()); + const auto tv = allocate->buffer()->as(); const auto& maybe_rfactor_domain = tv->domain()->hasRFactor() - ? tv->domain()->rfactorDomain() - : tv->domain()->rootDomain(); + ? tv->domain()->getRFactorDomain() + : tv->domain()->getRootDomain(); const auto nDims = std::count_if( maybe_rfactor_domain.begin(), maybe_rfactor_domain.end(), - [](const kir::IterDomain* id) { + [](const IterDomain* id) { return !id->isReduction() && - id->iterType() != IterType::BroadcastWithoutStride; + id->getIterType() != IterType::BroadcastWithoutStride; }); code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> " << varName(tv); @@ -177,7 +181,7 @@ class CudaKernelGenerator : private kir::IrVisitor { void genBody() { for (auto expr : kernel_->topLevelExprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } @@ -204,100 +208,93 @@ class CudaKernelGenerator : private kir::IrVisitor { return code_; } - std::string gen(const kir::Node* node) { + std::string gen(const Statement* stmt) { std::stringstream tmp_code; std::swap(tmp_code, code_); - auto replacement = replacement_map_.find(node); + auto replacement = replacement_map_.find(stmt); if (replacement != replacement_map_.end()) { - node = replacement->second; + stmt = replacement->second; } - node->accept(this); + OptOutConstDispatch::handle(stmt); std::swap(tmp_code, code_); return tmp_code.str(); } - // TODO(kir): consider automatic var naming - std::string varName(const kir::Val* val) { - std::string prefix = ""; - if (val->isA()) { - prefix = "T"; - } else { - prefix = typePrefix(val->dtype()); - } - - std::stringstream value_name; - if (val->name() != kInvalidStmName) { - value_name << prefix << val->name(); + std::string varName(const Val* val) { + std::stringstream name; + if (val->isA()) { + name << "T"; } else { - value_name << "k" << prefix << val->id(); + name << typePrefix(val->dtype()); } - return value_name.str(); + name << val->name(); + return name.str(); } - std::string genInline(const kir::Node* node) { + std::string genInline(const Statement* stmt) { const bool saved_inline = print_inline_; print_inline_ = true; - auto result = gen(node); + auto result = gen(stmt); print_inline_ = saved_inline; // NOLINTNEXTLINE(performance-no-automatic-move) return result; } - void visit(const kir::Predicate* node) final { - TORCH_INTERNAL_ASSERT(node->hasValue()); - code_ << gen(node->value()); + void handle(const kir::Predicate* pred) final { + TORCH_INTERNAL_ASSERT(pred->hasValue()); + code_ << gen(pred->value()); } - void visit(const kir::Bool* node) final { - const auto def = node->definition(); + void handle(const Bool* pred) final { + const auto def = pred->definition(); if (print_inline_ && def != nullptr) { code_ << "(" << gen(def) << ")"; - } else if (node->isConst()) { - code_ << (*node->value() ? "true" : "false"); + } else if (pred->isConst()) { + code_ << (*pred->value() ? "true" : "false"); } else { - code_ << varName(node); + code_ << varName(pred); } } - void visit(const kir::Double* node) final { - const auto def = node->definition(); + void handle(const Double* d) final { + const auto def = d->definition(); if (print_inline_ && def != nullptr) { code_ << "(" << gen(def) << ")"; - } else if (node->isConst()) { + } else if (d->isConst()) { const int digits = std::numeric_limits::max_digits10; - code_ << std::setprecision(digits) << *node->value(); + code_ << std::setprecision(digits) << *d->value(); } else { - code_ << varName(node); + code_ << varName(d); } } - void visit(const kir::Int* node) final { - const auto def = node->definition(); + void handle(const Int* i) final { + const auto def = i->definition(); if (print_inline_ && def != nullptr) { code_ << "(" << gen(def) << ")"; - } else if (node->isConst()) { - code_ << *node->value(); + } else if (i->isConst()) { + code_ << *i->value(); } else { - code_ << varName(node); + code_ << varName(i); } } - void visit(const kir::NamedScalar* node) final { + void handle(const NamedScalar* ns) final { // dim3 components are unsigned int. Cast to signed integer to // support negative indexing - if (node->getParallelIndex().has_value() || - node->getParallelDim().has_value()) { - code_ << "((nvfuser_index_t)" << node->name() << ")"; + if (ns->getParallelIndex().has_value() || + ns->getParallelDim().has_value()) { + code_ << "((nvfuser_index_t)" << ns->name() << ")"; } else { - code_ << node->name(); + code_ << ns->name(); } } - void visit(const kir::TensorIndex* node) final { - code_ << varName(node->view()) << "["; + void handle(const kir::TensorIndex* ti) final { + code_ << varName(ti->view()) << "["; bool first = true; - for (auto* ind : node->indices()) { + for (auto* ind : ti->indices()) { if (!ind->isZeroInt()) { if (!first) { code_ << " + "; @@ -314,29 +311,29 @@ class CudaKernelGenerator : private kir::IrVisitor { code_ << "]"; } - void visit(const kir::IterDomain* node) final { - TORCH_INTERNAL_ASSERT(false && "Unreachable"); + void handle(const IterDomain*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void visit(const kir::TensorDomain* node) final { - TORCH_INTERNAL_ASSERT(false && "Unreachable"); + void handle(const TensorDomain*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void visit(const kir::TensorView* tv) final { - TORCH_INTERNAL_ASSERT(false && "Unreachable"); + void handle(const TensorView*) final { + TORCH_INTERNAL_ASSERT(false, "Unreachable"); } - void visit(const kir::UnaryOp* node) final { + void handle(const UnaryOp* uop) final { bool is_vector_op = false; size_t vector_word_size = 1; - if (vectorize_scope_ && node->out()->isA()) { - auto ti = node->out()->as(); + if (vectorize_scope_ && uop->out()->isA()) { + auto ti = uop->out()->as(); bool vectorize_op = false; bool misaligned_op = false; - for (auto id : ti->view()->fuserTv()->domain()->domain()) { + for (auto id : ti->view()->domain()->domain()) { if (!isParallelTypeVectorize(id->getParallelType())) { continue; } @@ -358,84 +355,84 @@ class CudaKernelGenerator : private kir::IrVisitor { if (vectorize_op) { TORCH_INTERNAL_ASSERT( - node->operation() == UnaryOpType::Set, + uop->getUnaryOpType() == UnaryOpType::Set, "Cannot vectorize operations that are not sets. ", "Use cache_before and cache_after to store/load with vectorized reads into buffers."); is_vector_op = true; } if (misaligned_op) { - is_vector_op = (node->operation() == UnaryOpType::Set); + is_vector_op = (uop->getUnaryOpType() == UnaryOpType::Set); } - if (is_vector_op && !node->in()->isScalar()) { + if (is_vector_op && !uop->in()->isScalar()) { TORCH_INTERNAL_ASSERT( - node->out()->dtype() == node->in()->dtype(), + uop->out()->dtype() == uop->in()->dtype(), "Vectorized store/load requires input and output datatypes match."); } } if (is_vector_op) { - if (node->in()->isScalar()) { + if (uop->in()->isScalar()) { indent() << "reinterpret_cast<" - << "Array<" << node->out()->dtype() << ", " << vector_word_size + << "Array<" << uop->out()->dtype() << ", " << vector_word_size << ">*>" - << "(&" << gen(node->out()) << ")->set(" << gen(node->in()) + << "(&" << gen(uop->out()) << ")->set(" << gen(uop->in()) << ");\n"; } else { indent() << "*reinterpret_cast<" - << "Array<" << node->out()->dtype() << ", " << vector_word_size + << "Array<" << uop->out()->dtype() << ", " << vector_word_size << ">*>" - << "(&" << gen(node->out()) << ")" + << "(&" << gen(uop->out()) << ")" << " = *reinterpret_cast<" - << "Array<" << node->in()->dtype() << ", " << vector_word_size + << "Array<" << uop->in()->dtype() << ", " << vector_word_size << ">*>" - << "(&" << gen(node->in()) << ");\n"; + << "(&" << gen(uop->in()) << ");\n"; } return; } - if (node->out()->isA()) { - const auto op_type = node->operation(); + if (uop->out()->isA()) { + const auto op_type = uop->getUnaryOpType(); if (auto op = inline_op_str(op_type)) { - indent() << gen(node->out()) << " = " << *op << genInline(node->in()) + indent() << gen(uop->out()) << " = " << *op << genInline(uop->in()) << ";\n"; } return; } if (!print_inline_) { - indent() << gen(node->out()); - if (!node->out()->isScalar() && !node->in()->isScalar()) { + indent() << gen(uop->out()); + if (!uop->out()->isScalar() && !uop->in()->isScalar()) { code_ << "\n"; indent() << kTab; } code_ << " = "; } - const auto op_type = node->operation(); + const auto op_type = uop->getUnaryOpType(); if (auto op = inline_op_str(op_type)) { if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { - code_ << stringifyBooleanOp(op_type) << gen(node->in()); + uop->out()->dtype() == DataType::Bool) { + code_ << stringifyBooleanOp(op_type) << gen(uop->in()); } else { - code_ << *op << gen(node->in()); + code_ << *op << gen(uop->in()); } } else { if (op_type == UnaryOpType::Cast) { const auto cast_str = - cast_func_str({node->in()->dtype(), node->out()->dtype()}); + cast_func_str({uop->in()->dtype(), uop->out()->dtype()}); TORCH_INTERNAL_ASSERT( cast_str.has_value(), "Invalid cast. Input type: ", - node->in()->dtype(), + uop->in()->dtype(), ", output type: ", - node->out()->dtype()); + uop->out()->dtype()); code_ << cast_str.value(); } else { code_ << op_type; if (needFloatSuffix(op_type) && - node->out()->dtype() == DataType::Float) { + uop->out()->dtype() == DataType::Float) { code_ << "f"; } } @@ -444,7 +441,7 @@ class CudaKernelGenerator : private kir::IrVisitor { if (op_type == UnaryOpType::RandLike) { code_ << "rnd"; } else { - code_ << gen(node->in()); + code_ << gen(uop->in()); } code_ << ")"; } @@ -456,7 +453,7 @@ class CudaKernelGenerator : private kir::IrVisitor { std::string genBinaryOp( BinaryOpType op_type, - kir::Val* out, + Val* out, const std::string& lhs, const std::string& rhs) { std::stringstream expr; @@ -485,7 +482,7 @@ class CudaKernelGenerator : private kir::IrVisitor { // If one argument is a tensorview and the other is a scalar, make sure we // cast the scalar to the tensorview type - std::string scalarCast(kir::Val* lhs, kir::Val* rhs) { + std::string scalarCast(Val* lhs, Val* rhs) { // If neither are scalars return if (!((lhs->isScalar() || rhs->isScalar()) && (lhs->isA() || rhs->isA()))) { @@ -520,18 +517,18 @@ class CudaKernelGenerator : private kir::IrVisitor { } // If possible, replace pow with mul. Return true when successful. - bool genPowerWithMul(const kir::BinaryOp* node) { - if (node->operation() != BinaryOpType::Pow) { + bool genPowerWithMul(const BinaryOp* bop) { + if (bop->getBinaryOpType() != BinaryOpType::Pow) { return false; } - auto rhs = node->rhs(); + auto rhs = bop->rhs(); c10::optional exponent; - if (auto val_int = dynamic_cast(rhs)) { + if (auto val_int = dynamic_cast(rhs)) { if (val_int->isConst()) { exponent = val_int->value().value(); } - } else if (auto val_float = dynamic_cast(rhs)) { + } else if (auto val_float = dynamic_cast(rhs)) { if (val_float->isConst()) { auto fp_exp = val_float->value().value(); double int_exp = 0; @@ -550,7 +547,7 @@ class CudaKernelGenerator : private kir::IrVisitor { return false; } - auto lhs = gen(node->lhs()); + auto lhs = gen(bop->lhs()); if (print_inline_) { code_ << lhs << " * " << lhs; @@ -558,8 +555,8 @@ class CudaKernelGenerator : private kir::IrVisitor { code_ << " * " << lhs; } } else { - indent() << gen(node->out()); - if (node->out()->isScalar()) { + indent() << gen(bop->out()); + if (bop->out()->isScalar()) { code_ << " = " << lhs << " * " << lhs; if (exponent.value() == 3) { code_ << " * " << lhs; @@ -579,24 +576,24 @@ class CudaKernelGenerator : private kir::IrVisitor { return true; } - void visit(const kir::BinaryOp* node) final { + void handle(const BinaryOp* bop) final { // Try replacing pow with mul - if (genPowerWithMul(node)) { + if (genPowerWithMul(bop)) { return; } - const auto op_type = node->operation(); + const auto op_type = bop->getBinaryOpType(); if (print_inline_) { // Inline expression: `lhs op rhs` code_ << genBinaryOp( - op_type, node->out(), gen(node->lhs()), gen(node->rhs())); + op_type, bop->out(), gen(bop->lhs()), gen(bop->rhs())); } else { - indent() << gen(node->out()); - if (node->out()->isScalar()) { + indent() << gen(bop->out()); + if (bop->out()->isScalar()) { // Single line: `out = lhs op rhs;` code_ << " = " << genBinaryOp( - op_type, node->out(), gen(node->lhs()), gen(node->rhs())); + op_type, bop->out(), gen(bop->lhs()), gen(bop->rhs())); } else { // Split TensorView expressions across multiple lines: // @@ -605,64 +602,64 @@ class CudaKernelGenerator : private kir::IrVisitor { // op rhs; // - auto cast = scalarCast(node->lhs(), node->rhs()); + auto cast = scalarCast(bop->lhs(), bop->rhs()); if (auto op = inline_op_str(op_type)) { code_ << "\n"; - indent() << kTab << "= " << (node->lhs()->isScalar() ? cast : "") - << gen(node->lhs()) << "\n"; + indent() << kTab << "= " << (bop->lhs()->isScalar() ? cast : "") + << gen(bop->lhs()) << "\n"; indent() << kTab; if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { + bop->out()->dtype() == DataType::Bool) { code_ << stringifyBooleanOp(op_type); } else { code_ << *op; } - code_ << " " << (node->rhs()->isScalar() ? cast : "") - << gen(node->rhs()); + code_ << " " << (bop->rhs()->isScalar() ? cast : "") + << gen(bop->rhs()); } else { - if (integer_op_str(op_type) && isIntegralType(node->out()->dtype())) { + if (integer_op_str(op_type) && isIntegralType(bop->out()->dtype())) { auto int_op = integer_op_str(op_type); code_ << " = " << *int_op << "(\n"; } else { std::stringstream op_str; op_str << op_type; if (needFloatSuffix(op_type) && - node->out()->dtype() == DataType::Float) { + bop->out()->dtype() == DataType::Float) { op_str << "f"; } code_ << " = " << op_str.str() << "(\n"; } - indent() << kTab << (node->lhs()->isScalar() ? cast : "") - << gen(node->lhs()) << ",\n"; - indent() << kTab << (node->rhs()->isScalar() ? cast : "") - << gen(node->rhs()) << ")"; + indent() << kTab << (bop->lhs()->isScalar() ? cast : "") + << gen(bop->lhs()) << ",\n"; + indent() << kTab << (bop->rhs()->isScalar() ? cast : "") + << gen(bop->rhs()) << ")"; } } code_ << ";\n"; } } - void visit(const kir::TernaryOp* node) final { + void handle(const TernaryOp* top) final { if (!print_inline_) { - indent() << gen(node->out()); - if (!node->out()->isScalar()) { + indent() << gen(top->out()); + if (!top->out()->isScalar()) { code_ << "\n"; indent() << kTab; } code_ << " = "; } - code_ << node->operation() << "(" << gen(node->in1()) << ", "; + code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", "; // Make sure the two operands of where has the same // type. Note that compiling "where(0.0f, 0.0)" fails because of // the overloading ambiguity. - if (node->operation() == TernaryOpType::Where) { - auto cast = scalarCast(node->in2(), node->in3()); - code_ << (node->in2()->isScalar() ? cast : "") << gen(node->in2()) << ", " - << (node->in3()->isScalar() ? cast : "") << gen(node->in3()) << ")"; + if (top->getTernaryOpType() == TernaryOpType::Where) { + auto cast = scalarCast(top->in2(), top->in3()); + code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << ", " + << (top->in3()->isScalar() ? cast : "") << gen(top->in3()) << ")"; } else { - code_ << gen(node->in2()) << ", " << gen(node->in3()) << ")"; + code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")"; } if (!print_inline_) { @@ -670,7 +667,7 @@ class CudaKernelGenerator : private kir::IrVisitor { } } - std::string genReductionOp(BinaryOpType op_type, kir::Val* out) { + std::string genReductionOp(BinaryOpType op_type, Val* out) { std::stringstream lambda; DataType data_type = out->dtype(); lambda << "[](" << data_type << " &a, " << data_type << " b) " @@ -678,47 +675,45 @@ class CudaKernelGenerator : private kir::IrVisitor { return lambda.str(); } - void visit(const kir::BroadcastOp* node) final { - TORCH_INTERNAL_ASSERT(node->out()->isA()); - const auto tensor_index = node->out()->as(); + void handle(const BroadcastOp* stmt) final { + TORCH_INTERNAL_ASSERT(stmt->out()->isA()); + const auto tensor_index = stmt->out()->as(); - const ParallelTypeBitmap domains = - kernel_->predicateMap().getParallelBroadcastDomains( - tensor_index->view()->fuserTv()); + const ParallelTypeBitmap parallel_types = + kernel_->summary().broadcast_parallel_types.at(stmt); - const bool thread_x = domains.get(ParallelType::TIDx); - const bool thread_y = domains.get(ParallelType::TIDy); - const bool thread_z = domains.get(ParallelType::TIDz); - const bool block_x = domains.get(ParallelType::BIDx); - const bool block_y = domains.get(ParallelType::BIDy); - const bool block_z = domains.get(ParallelType::BIDz); - - const bool grid_broadcast_needed = block_x || block_y || block_z; - const bool block_broadcast_needed = thread_x || thread_y || thread_z; + if (parallel_types.none()) { + // Not parallelized + indent() << gen(stmt->out()) << "\n"; + indent() << kTab << " = " << gen(stmt->in()) << ";\n"; + return; + } TORCH_INTERNAL_ASSERT( - !grid_broadcast_needed, - "Parallel broadcast across blocks not supported"); - - if (block_broadcast_needed) { - const auto data_type = node->out()->dtype(); - indent() << "broadcast::blockBroadcast<" << (thread_x ? "true" : "false") - << ", " << (thread_y ? "true" : "false") << ", " - << (thread_z ? "true" : "false") << ">(\n"; - indent() << kTab << gen(node->out()) << ",\n"; - indent() << kTab << gen(node->in()) << ",\n"; - indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; - TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - indent() << kTab << genInline(node->predicate()) << ");\n"; - } else { - indent() << gen(node->out()) << "\n"; - indent() << kTab << " = " << gen(node->in()) << ";\n"; + !parallel_types.hasBID(), + "Parallel broadcast across blocks should have been translated to a GridBroadcast IR node"); + + std::stringstream flags_str; + for (const ParallelType pt : kParallelTypeTIDs) { + const bool parallel_bcast = parallel_types.get(pt); + if (pt != kParallelTypeTIDs[0]) { + flags_str << ", "; + } + flags_str << (parallel_bcast ? "true" : "false"); } + + const auto data_type = stmt->out()->dtype(); + indent() << "broadcast::blockBroadcast<" << flags_str.str() << ">(\n"; + indent() << kTab << gen(stmt->out()) << ",\n"; + indent() << kTab << gen(stmt->in()) << ",\n"; + indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; + TORCH_INTERNAL_ASSERT( + stmt->predicate() != nullptr && stmt->predicate()->hasValue()); + indent() << kTab << genInline(stmt->predicate()) << ");\n"; } void genWarpReductionOp( - const kir::ReductionOp* node, + const ReductionOp* rop, const IterDomain* reduction_id) { bool is_single_warp = kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp; @@ -729,24 +724,25 @@ class CudaKernelGenerator : private kir::IrVisitor { } else { code_ << "(\n"; } - indent() << kTab << gen(node->out()) << ",\n"; - indent() << kTab << gen(node->in()) << ",\n"; - indent() << kTab << genReductionOp(node->operation(), node->out()) << ",\n"; + indent() << kTab << gen(rop->out()) << ",\n"; + indent() << kTab << gen(rop->in()) << ",\n"; + indent() << kTab << genReductionOp(rop->getReductionOpType(), rop->out()) + << ",\n"; indent() << kTab << "threadIdx,\n"; indent() << kTab << "blockDim,\n"; - indent() << kTab << "static_cast<" << node->out()->dtype() + indent() << kTab << "static_cast<" << rop->out()->dtype() << "*>(shared_mem),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - indent() << kTab << genInline(node->predicate()) << ",\n"; - indent() << kTab << node->out()->dtype() << "(" << genInline(node->init()) + rop->predicate() != nullptr && rop->predicate()->hasValue()); + indent() << kTab << genInline(rop->predicate()) << ",\n"; + indent() << kTab << rop->out()->dtype() << "(" << genInline(rop->init()) << "));\n"; } - void visit(const kir::ReductionOp* node) final { - TORCH_INTERNAL_ASSERT(node->out()->isA()); + void handle(const ReductionOp* rop) final { + TORCH_INTERNAL_ASSERT(rop->out()->isA()); - const auto out = node->out()->as(); + const auto out = rop->out()->as(); const auto domain = out->view()->domain(); const bool has_block_reduce = domain->hasBlockReduction(); @@ -754,18 +750,18 @@ class CudaKernelGenerator : private kir::IrVisitor { if (!has_block_reduce && !has_grid_reduce) { const auto gen_out = gen(out); - const auto op_type = node->operation(); + const auto op_type = rop->getReductionOpType(); indent() << gen_out << " = " - << genBinaryOp(op_type, out, gen_out, gen(node->in())) << ";\n"; + << genBinaryOp(op_type, out, gen_out, gen(rop->in())) << ";\n"; return; } - if (auto reduction_id = ir_utils::getMaybeWarpReductionDim(node)) { - genWarpReductionOp(node, reduction_id.value()); + if (auto reduction_id = ir_utils::getMaybeWarpReductionDim(rop)) { + genWarpReductionOp(rop, reduction_id.value()); return; } - const auto par_domains = ir_utils::getParallelDomains(node->out()); + const auto par_domains = ir_utils::getParallelDomains(rop->out()); // Get parallel reduction domains const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end() && @@ -777,14 +773,14 @@ class CudaKernelGenerator : private kir::IrVisitor { par_domains.find(ParallelType::TIDz) != par_domains.end() && par_domains.at(ParallelType::TIDz)->isReduction(); - const auto data_type = node->out()->dtype(); - const auto op_type = node->operation(); + const auto data_type = rop->out()->dtype(); + const auto op_type = rop->getReductionOpType(); if (has_block_reduce) { if (has_grid_reduce) { indent() << data_type << " " << "block_result_" << block_reduce_name_ << "=" - << gen(node->init()) << ";\n"; + << gen(rop->init()) << ";\n"; } indent() << "blockReduce<" << (tidx ? "true" : "false") << ", " << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false") @@ -792,44 +788,43 @@ class CudaKernelGenerator : private kir::IrVisitor { if (has_grid_reduce) { indent() << kTab << "block_result_" << block_reduce_name_ << ",\n"; } else { - indent() << kTab << gen(node->out()) << ",\n"; + indent() << kTab << gen(rop->out()) << ",\n"; } - indent() << kTab << gen(node->in()) << ",\n"; - indent() << kTab << genReductionOp(op_type, node->out()) << ",\n"; + indent() << kTab << gen(rop->in()) << ",\n"; + indent() << kTab << genReductionOp(op_type, rop->out()) << ",\n"; indent() << kTab << "threadIdx,\n"; indent() << kTab << "blockDim,\n"; indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + rop->predicate() != nullptr && rop->predicate()->hasValue()); + auto read_pred = genInline(rop->predicate()); indent() << kTab << read_pred << ",\n"; // Pass the write predicate if available and different from the // default predicate. The blockReduce runtime function uses the // default predicate for both read and write when only the // default one is given. - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (rop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(rop->writePredicate()->hasValue()); + auto write_pred = genInline(rop->writePredicate()); indent() << kTab << write_pred << ",\n"; } - indent() << kTab << data_type << "(" << genInline(node->init()) - << "));\n"; + indent() << kTab << data_type << "(" << genInline(rop->init()) << "));\n"; } } - void visit(const kir::WelfordOp* node) final { - TORCH_INTERNAL_ASSERT(node->out()->isA()); + void handle(const WelfordOp* wop) final { + TORCH_INTERNAL_ASSERT(wop->out()->isA()); - const auto out = node->out()->as(); + const auto out = wop->out()->as(); const auto domain = out->view()->domain(); - const auto out_var = node->outVar(); - const auto out_avg = node->outAvg(); - const auto out_N = node->outN(); + const auto out_var = wop->outVar(); + const auto out_avg = wop->outAvg(); + const auto out_N = wop->outN(); - const auto in_var = node->inVar(); - const auto in_avg = node->inAvg(); - const auto in_N = node->inN(); + const auto in_var = wop->inVar(); + const auto in_avg = wop->inAvg(); + const auto in_N = wop->inN(); const bool has_block_reduce = domain->hasBlockReduction(); const bool has_grid_reduce = domain->hasGridReduction(); @@ -852,7 +847,7 @@ class CudaKernelGenerator : private kir::IrVisitor { return; } - const auto par_domains = ir_utils::getParallelDomains(node->out()); + const auto par_domains = ir_utils::getParallelDomains(wop->out()); // Get parallel reduction domains const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end() && @@ -864,20 +859,20 @@ class CudaKernelGenerator : private kir::IrVisitor { par_domains.find(ParallelType::TIDz) != par_domains.end() && par_domains.at(ParallelType::TIDz)->isReduction(); - const auto data_type = node->out()->dtype(); + const auto data_type = wop->out()->dtype(); if (has_block_reduce) { if (has_grid_reduce) { // allocate block result indent() << data_type << " " << "block_result_avg_" << block_reduce_name_ << " = " - << gen(node->initAvg()) << ";\n"; + << gen(wop->initAvg()) << ";\n"; indent() << data_type << " " << "block_result_var_" << block_reduce_name_ << " = " - << gen(node->initVar()) << ";\n"; + << gen(wop->initVar()) << ";\n"; indent() << DataType::Int << " " << "block_result_n_" << block_reduce_name_ << " = " - << gen(node->initN()) << ";\n"; + << gen(wop->initN()) << ";\n"; } indent() << "blockWelford<" << (tidx ? "true" : "false") << ", " << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false") @@ -887,9 +882,9 @@ class CudaKernelGenerator : private kir::IrVisitor { << kTab << "block_result_var_" << block_reduce_name_ << ",\n" << kTab << "block_result_n_" << block_reduce_name_ << ",\n"; } else { - indent() << kTab << gen(node->outAvg()) << ",\n"; - indent() << kTab << gen(node->outVar()) << ",\n"; - indent() << kTab << gen(node->outN()) << ",\n"; + indent() << kTab << gen(wop->outAvg()) << ",\n"; + indent() << kTab << gen(wop->outVar()) << ",\n"; + indent() << kTab << gen(wop->outN()) << ",\n"; } indent() << " " << gen(in_avg) << ",\n"; if (in_var) { @@ -907,14 +902,14 @@ class CudaKernelGenerator : private kir::IrVisitor { << "*>(shared_mem_var),\n"; indent() << kTab << "reinterpret_cast<" << DataType::Int << "*>(shared_mem_n),\n"; - TORCH_INTERNAL_ASSERT(node->predicate() != nullptr); + TORCH_INTERNAL_ASSERT(wop->predicate() != nullptr); TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + wop->predicate() != nullptr && wop->predicate()->hasValue()); + auto read_pred = genInline(wop->predicate()); indent() << kTab << read_pred << ",\n"; - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (wop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(wop->writePredicate()->hasValue()); + auto write_pred = genInline(wop->writePredicate()); indent() << kTab << write_pred << ",\n"; } indent() << kTab << data_type << "(0));\n"; @@ -954,8 +949,8 @@ class CudaKernelGenerator : private kir::IrVisitor { return flags.str(); } - void visit(const kir::GridReduction* node) final { - const auto rop = node->reduction_op(); + void handle(const kir::GridReduction* grop) final { + const auto rop = grop->reduction_op(); TORCH_INTERNAL_ASSERT(rop->out()->isA()); const auto out = rop->out()->as(); @@ -963,19 +958,17 @@ class CudaKernelGenerator : private kir::IrVisitor { TORCH_INTERNAL_ASSERT(domain->hasGridReduction()); const auto data_type = rop->out()->dtype(); - const auto op_type = rop->operation(); + const auto op_type = rop->getReductionOpType(); TORCH_INTERNAL_ASSERT( - node->reduction_buffer()->buffer()->isA()); - TORCH_INTERNAL_ASSERT( - node->sync_buffer()->buffer()->isA()); + grop->reduction_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA()); const auto work_buffer = - node->reduction_buffer()->buffer()->as(); - const auto sync_buffer = - node->sync_buffer()->buffer()->as(); + grop->reduction_buffer()->buffer()->as(); + const auto sync_buffer = grop->sync_buffer()->buffer()->as(); const std::string flags_str = - generateGridReduceTemplateFlags(rop, node->threadPredicate()); + generateGridReduceTemplateFlags(rop, grop->threadPredicate()); const bool persistent_sync = kernel_->summary().has_cooperative_grid_reduction; @@ -996,44 +989,46 @@ class CudaKernelGenerator : private kir::IrVisitor { indent() << kTab << varName(sync_buffer) << ",\n"; indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + grop->predicate() != nullptr && grop->predicate()->hasValue()); + auto read_pred = genInline(grop->predicate()); indent() << kTab << read_pred << ",\n"; - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (grop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue()); + auto write_pred = genInline(grop->writePredicate()); indent() << kTab << write_pred << ",\n"; } else { indent() << kTab << read_pred << ",\n"; } indent() << kTab << data_type << "(" - << genInline(node->reduction_op()->init()) << "));\n"; + << genInline(grop->reduction_op()->init()) << "));\n"; } - void visit(const kir::GridBroadcast* node) final { - const auto bop = node->broadcast_op(); + void handle(const kir::GridBroadcast* grop) final { + const auto bop = grop->broadcast_op(); TORCH_INTERNAL_ASSERT(bop->out()->isA()); + const ParallelTypeBitmap parallel_types = + kernel_->summary().broadcast_parallel_types.at(bop); + + TORCH_INTERNAL_ASSERT( + parallel_types.hasBID(), + "GridBroadcast needs to be used with a broadcast op that is parallelized with the BID parallel types"); + const auto out = bop->out()->as(); const auto domain = out->view()->domain(); - TORCH_INTERNAL_ASSERT(domain->hasGridBroadcast()); const auto data_type = bop->out()->dtype(); TORCH_INTERNAL_ASSERT( - node->broadcast_buffer()->buffer()->isA()); - TORCH_INTERNAL_ASSERT( - node->sync_buffer()->buffer()->isA()); + grop->broadcast_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA()); const auto work_buffer = - node->broadcast_buffer()->buffer()->as(); - const auto sync_buffer = - node->sync_buffer()->buffer()->as(); + grop->broadcast_buffer()->buffer()->as(); + const auto sync_buffer = grop->sync_buffer()->buffer()->as(); - const auto par_domains = ir_utils::getParallelDomains(out); std::stringstream flags_str; for (const ParallelType pt : kParallelTypeThreads) { - const bool parallel_bcast = par_domains.find(pt) != par_domains.end() && - par_domains.at(pt)->isBroadcast(); + const bool parallel_bcast = parallel_types.get(pt); if (pt != kParallelTypeThreads[0]) { flags_str << ", "; } @@ -1041,7 +1036,7 @@ class CudaKernelGenerator : private kir::IrVisitor { } // Since block-level broadcast has not necessarily been performed before - // this function call, so grid broadcast may be broadcasting across both + // this function call, so grid broadcast may be broadcasting across both // the grid and the block level. indent() << "grid_broadcast::broadcast<" << flags_str.str() << ">(\n"; indent() << kTab << gen(bop->out()) << ",\n"; @@ -1049,12 +1044,12 @@ class CudaKernelGenerator : private kir::IrVisitor { indent() << kTab << "&" << varName(work_buffer) << "[0],\n"; indent() << kTab << varName(sync_buffer) << ",\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - indent() << kTab << genInline(node->predicate()) << ");\n"; + grop->predicate() != nullptr && grop->predicate()->hasValue()); + indent() << kTab << genInline(grop->predicate()) << ");\n"; } - void visit(const kir::GridWelford* node) final { - const auto wop = node->welford_op(); + void handle(const kir::GridWelford* gwop) final { + const auto wop = gwop->welford_op(); TORCH_INTERNAL_ASSERT(wop->outAvg()->isA()); const auto out = wop->out()->as(); @@ -1063,21 +1058,19 @@ class CudaKernelGenerator : private kir::IrVisitor { const auto data_type = out->dtype(); - TORCH_INTERNAL_ASSERT(node->var_buffer()->buffer()->isA()); - TORCH_INTERNAL_ASSERT( - node->sync_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(gwop->var_buffer()->buffer()->isA()); + TORCH_INTERNAL_ASSERT(gwop->sync_buffer()->buffer()->isA()); - const auto avg_buffer = node->avg_buffer()->buffer()->as(); - const auto var_buffer = node->var_buffer()->buffer()->as(); - const auto n_buffer = node->N_buffer()->buffer()->as(); - const auto sync_buffer = - node->sync_buffer()->buffer()->as(); + const auto avg_buffer = gwop->avg_buffer()->buffer()->as(); + const auto var_buffer = gwop->var_buffer()->buffer()->as(); + const auto n_buffer = gwop->N_buffer()->buffer()->as(); + const auto sync_buffer = gwop->sync_buffer()->buffer()->as(); const bool persistent_sync = kernel_->summary().has_cooperative_grid_reduction; const std::string flags_str = - generateGridReduceTemplateFlags(wop, node->threadPredicate()); + generateGridReduceTemplateFlags(wop, gwop->threadPredicate()); // Since block-level reduction is already done, those dimensions // with tidx/y/z being true do not participate in the grid reduction. @@ -1112,12 +1105,12 @@ class CudaKernelGenerator : private kir::IrVisitor { indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype() << "*>(shared_mem_n),\n"; TORCH_INTERNAL_ASSERT( - node->predicate() != nullptr && node->predicate()->hasValue()); - auto read_pred = genInline(node->predicate()); + gwop->predicate() != nullptr && gwop->predicate()->hasValue()); + auto read_pred = genInline(gwop->predicate()); indent() << kTab << read_pred << ",\n"; - if (node->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue()); - auto write_pred = genInline(node->writePredicate()); + if (gwop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue()); + auto write_pred = genInline(gwop->writePredicate()); indent() << kTab << write_pred << ",\n"; } else { indent() << kTab << read_pred << ",\n"; @@ -1128,27 +1121,26 @@ class CudaKernelGenerator : private kir::IrVisitor { void handleScope(const kir::Scope& scope) { for (auto expr : scope.exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } - void visit(const kir::ForLoop* node) final { - // TODO(kir): handle this during lowering - if (node->iter_domain()->isBroadcast()) { - handleScope(node->body()); + void handle(const kir::ForLoop* loop) final { + if (loop->iter_domain()->isBroadcast()) { + handleScope(loop->body()); return; - } else if (node->vectorize()) { - vectorize_scope_ = node->vectorize(); - handleScope(node->body()); + } else if (loop->vectorize()) { + vectorize_scope_ = loop->vectorize(); + handleScope(loop->body()); vectorize_scope_ = false; return; - } else if (node->iter_domain()->isStride()) { + } else if (loop->iter_domain()->isStride()) { // A stride domain only executes the loop body with the loop // index being zero. indent() << "constexpr " << "nvfuser_index_t" - << " " << gen(node->index()) << " = 0;\n"; - handleScope(node->body()); + << " " << gen(loop->index()) << " = 0;\n"; + handleScope(loop->body()); return; } @@ -1168,56 +1160,82 @@ class CudaKernelGenerator : private kir::IrVisitor { // necessary since the loop stop value just needs to be <= the // IterDomain extent. However, at this point, this conservative // analysis seems sufficient. - if (node->stop() == node->iter_domain()->extent() && - node->iter_domain()->isThread()) { + if (loop->stop() == loop->iter_domain()->extent() && + loop->iter_domain()->isThread()) { // Register a replacement of references to the loop index with // the loop start value. - replacement_map_.insert({node->index(), node->start()}); - handleScope(node->body()); - replacement_map_.erase(node->index()); + replacement_map_.insert({loop->index(), loop->start()}); + handleScope(loop->body()); + replacement_map_.erase(loop->index()); return; } - if (node->start()->isZeroInt() && node->stop()->isOneInt()) { + if (loop->start()->isZeroInt() && loop->stop()->isOneInt()) { indent() << "constexpr " << "nvfuser_index_t" - << " " << gen(node->index()) << " = 0;\n"; - handleScope(node->body()); + << " " << gen(loop->index()) << " = 0;\n"; + handleScope(loop->body()); + return; + } else if ( + // Special case handling for a pattern where start == end - 1. + loop->start()->definition() != nullptr && + loop->start()->definition()->isA() && + loop->start()->definition()->as()->getBinaryOpType() == + BinaryOpType::Sub && + loop->start()->definition()->as()->lhs() == loop->stop() && + loop->start()->definition()->as()->rhs()->isOneInt()) { + indent() << "const " + << "nvfuser_index_t" + << " " << gen(loop->index()) << " = " << genInline(loop->start()) + << ";\n"; + handleScope(loop->body()); return; } - const auto gen_index = gen(node->index()); - const auto gen_start = genInline(node->start()); - const auto gen_stop = genInline(node->stop()); - const auto gen_step = genInline(node->step()); + const auto gen_index = gen(loop->index()); + const auto gen_start = genInline(loop->start()); + const auto gen_stop = genInline(loop->stop()); + const auto gen_step = genInline(loop->step()); std::stringstream step_code; - if (node->step()->isOneInt()) { + if (loop->step()->isOneInt()) { step_code << "++" << gen_index; } else { step_code << gen_index << " += " << gen_step; } - if (node->isUnrolled()) { + if (loop->isUnrolled()) { indent() << "#pragma unroll\n"; } else { indent() << "#pragma unroll 1\n"; } - indent() << "for(nvfuser_index_t " << gen_index << " = " << gen_start - << "; " << gen_index << " < " << gen_stop << "; " - << step_code.str() << ") "; + + indent() << "for(nvfuser_index_t " << gen_index; + if (loop->iter_domain()->isParallelized()) { + code_ << " = " << gen_start << "; "; + } else { + // Do not start at the start of the ID when not parallelized. Instead, + // start at 0. Predicates will protect buffers between 0 and ID->start(), + // however if we started at ID->start and extent == ID->start, we could + // have a "degenerate" loop (loop with no iterations). It may not be an + // issue to have a 0-sized loop, but all potential consequences haven't + // been covered. One example is WAR analysis which could incorrectly think + // a barrier inside a 0-sized loop actually provides protection. + code_ << " = 0; "; + } + code_ << gen_index << " < " << gen_stop << "; " << step_code.str() << ") "; startBlock(true); - handleScope(node->body()); + handleScope(loop->body()); endBlock(); } - void visit(const kir::IfThenElse* node) final { - auto conditional = node->predicate()->value(); + void handle(const kir::IfThenElse* ite) final { + auto conditional = ite->predicate()->value(); if (conditional->isConst()) { // If the conditional is a constant, then the IfThenElse is not required if (conditional->value().value()) { - handleScope(node->thenBody()); + handleScope(ite->thenBody()); } else { - handleScope(node->elseBody()); + handleScope(ite->elseBody()); } return; } @@ -1226,41 +1244,40 @@ class CudaKernelGenerator : private kir::IrVisitor { // "then" block startBlock(true); - handleScope(node->thenBody()); + handleScope(ite->thenBody()); // "else" block (optional) - if (node->hasElse()) { + if (ite->hasElse()) { endBlock(" else "); startBlock(true); - handleScope(node->elseBody()); + handleScope(ite->elseBody()); } endBlock(); } - // TODO(kir): fold initialization into Allocate - void visit(const kir::Allocate* node) final { - const auto buffer_dtype = node->buffer()->dtype(); + void handle(const kir::Allocate* alloc) final { + const auto buffer_dtype = alloc->buffer()->dtype(); - if (!node->buffer()->isA()) { - indent() << buffer_dtype << " " << gen(node->buffer()) << ";\n"; + if (!alloc->buffer()->isA()) { + indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n"; return; } - const auto tv = node->buffer()->as(); + const auto tv = alloc->buffer()->as(); - const auto size = node->size(); + const auto size = alloc->size(); TORCH_INTERNAL_ASSERT(size != nullptr); - if (node->alias() != nullptr) { - // Allocate alias another Allocate node - const auto alias_tv = node->alias()->buffer()->as(); - indent() << "// Alias Allocation - " << node->memoryType() << "\n"; + if (alloc->alias() != nullptr) { + // Allocate alias another Allocate stmt + const auto alias_tv = alloc->alias()->buffer()->as(); + indent() << "// Alias Allocation - " << alloc->memoryType() << "\n"; indent() << buffer_dtype << "* " << varName(tv) << " = " << varName(alias_tv) << ";\n"; } else { // Standard Memory Allocation - switch (tv->memoryType()) { + switch (tv->getMemoryType()) { case MemoryType::Global: indent() << "// Allocate global tensor " << varName(tv) << "\n"; break; @@ -1292,7 +1309,7 @@ class CudaKernelGenerator : private kir::IrVisitor { } } - void visit(const kir::Sync* node) final { + void handle(const kir::Sync*) final { // Use a custom synchronization method if enabled if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { indent() << "block_sync::sync();\n"; @@ -1301,11 +1318,11 @@ class CudaKernelGenerator : private kir::IrVisitor { } } - void visit(const kir::InitMagicZero* node) final { + void handle(const kir::InitMagicZero*) final { indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; } - void visit(const kir::UpdateMagicZero* node) final { + void handle(const kir::UpdateMagicZero*) final { indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; } @@ -1314,15 +1331,13 @@ class CudaKernelGenerator : private kir::IrVisitor { const kir::Kernel* kernel_; int block_nest_level_ = 0; int block_reduce_name_ = 0; - - // TODO(kir): replace with explicit assignment statements bool print_inline_ = false; // Mark when we are inside of a vectorized for-loop bool vectorize_scope_ = false; //! Holds active replacement mappings during codegen - std::unordered_map replacement_map_; + std::unordered_map replacement_map_; }; } // namespace diff --git a/torch/csrc/jit/codegen/cuda/codegen.h b/torch/csrc/jit/codegen/cuda/codegen.h index 2ffbb872155..31e4fb70736 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.h +++ b/torch/csrc/jit/codegen/cuda/codegen.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 45f744d7e2f..f51e0fe1bc9 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -59,14 +59,8 @@ bool validateDomain(TensorView* tv, TensorDomain* new_td) { unsigned int getReplayablePosPasC( TensorView* producer, TensorView* consumer, - const ComputeAtRootDomainMap& root_map_, + const std::unordered_set& unmappable_producer_dims, ComputeAtMode mode) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline reduction structures. - auto mappable_roots = - root_map_.getMappableDims(producer->domain(), consumer->domain()); - // Check if any consumer dimensions are marked as vectorize as producer can // not be inlined to vectorized dimensions in consumer. auto c_dom = consumer->domain()->domain(); @@ -124,9 +118,14 @@ unsigned int getReplayablePosPasC( if (std::any_of( consumer_root_dim_ids.begin(), consumer_root_dim_ids.end(), - [&mappable_roots, &c2p_root_map](IterDomain* root_id) { - return mappable_roots.find(root_id) == mappable_roots.end() && - c2p_root_map.find(root_id) != c2p_root_map.end(); + [&unmappable_producer_dims, &c2p_root_map](IterDomain* c_root_id) { + auto p_root_id_it = c2p_root_map.find(c_root_id); + if (p_root_id_it == c2p_root_map.end()) { + return false; + } + auto p_id = p_root_id_it->second; + return unmappable_producer_dims.find(p_id) != + unmappable_producer_dims.end(); })) { continue; } @@ -146,14 +145,8 @@ unsigned int getReplayablePosPasC( unsigned int getReplayablePosCasP( TensorView* consumer, TensorView* producer, - const ComputeAtRootDomainMap& root_map_, + const std::unordered_set& unmappable_producer_dims, ComputeAtMode mode) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline reduction structures. - auto mappable_roots = - root_map_.getMappableDims(producer->domain(), consumer->domain()); - auto p_dom = producer->domain()->domain(); auto first_reduction = std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { @@ -208,10 +201,11 @@ unsigned int getReplayablePosCasP( if (std::any_of( producer->getMaybeRFactorDomain().begin(), producer->getMaybeRFactorDomain().end(), - [&mappable_roots, &all_vals](IterDomain* root_id) { - return std::find(all_vals.begin(), all_vals.end(), root_id) != + [&unmappable_producer_dims, &all_vals](IterDomain* p_root_id) { + return std::find(all_vals.begin(), all_vals.end(), p_root_id) != all_vals.end() && - mappable_roots.find(root_id) == mappable_roots.end(); + unmappable_producer_dims.find(p_root_id) != + unmappable_producer_dims.end(); })) { continue; } @@ -446,7 +440,8 @@ unsigned int ComputeAt::backwardComputeAt_impl( FUSER_PERF_SCOPE("backwardComputeAt_impl"); auto max_consumer_compute_at_pos = - getReplayablePosPasC(producer, consumer, root_map_, mode_); + getReplayablePosPasC(producer, consumer, unmappable_dims_, mode_); + if (mode_ == ComputeAtMode::BestEffort) { consumer_compute_at_pos = std::min(consumer_compute_at_pos, max_consumer_compute_at_pos); @@ -517,7 +512,7 @@ unsigned int ComputeAt::forwardComputeAt_impl( FUSER_PERF_SCOPE("forwardComputeAt_impl"); auto max_producer_compute_at_pos = - getReplayablePosCasP(consumer, producer, root_map_, mode_); + getReplayablePosCasP(consumer, producer, unmappable_dims_, mode_); if (mode_ == ComputeAtMode::BestEffort) { producer_compute_at_pos = @@ -865,6 +860,25 @@ void ComputeAt::runPass() { } } +void ComputeAt::buildUnmappableDims() { + auto all_tvs = ir_utils::allTvs(producer_->fusion()); + for (auto tv : all_tvs) { + auto consumers = ir_utils::consumerTvsOf(tv); + for (auto consumer : consumers) { + // Grab dimensions in producer and consumer that are mappable to eachother + // based on the computeAtRootDomainMap. This will tell us which dimensions + // can be inlined based on avoiding trying to inline reduction structures. + auto mappable_roots = + root_map_.getMappableDims(tv->domain(), consumer->domain()); + for (auto tv_root_id : tv->getMaybeRFactorDomain()) { + if (mappable_roots.find(tv_root_id) == mappable_roots.end()) { + unmappable_dims_.emplace(tv_root_id); + } + } + } + } +} + ComputeAt::ComputeAt( TensorView* _producer, TensorView* _consumer, @@ -903,6 +917,8 @@ ComputeAt::ComputeAt( setCommonConsumer(); root_map_.build(); + + buildUnmappableDims(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 391225218db..75fca5705ed 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -2,11 +2,12 @@ #include +#include #include -#include #include #include +#include #include namespace torch { @@ -68,6 +69,10 @@ class ComputeAt { // call. void setCommonConsumer(); + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + // Propagate backward from consumer to producer, check if it increase // computeAt position on tensors, if so take it! void traverseBackward(); @@ -106,6 +111,9 @@ class ComputeAt { // Producer use chains set in, used in a few spots. std::deque> producer_use_chains_; + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims_; + ComputeAt( TensorView* _producer, TensorView* _consumer, diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 6671fc37546..f46a7495163 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include @@ -488,71 +487,6 @@ void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) { } } } - - if (gpu_lower != nullptr) { - convertToKir(fusion, gpu_lower); - } -} - -void ComputeAtMap::convertToKir(Fusion* fusion, GpuLower* gpu_lower) { - TORCH_INTERNAL_ASSERT(fusion != nullptr); - TORCH_INTERNAL_ASSERT(gpu_lower != nullptr); - - has_lowered_kir_ = true; - - std::unordered_map< - std::shared_ptr>, - std::shared_ptr>> - disjoint_set_2_kir; - - for (const auto& disjoint_iter_set : disjoint_iter_set_maps_) { - auto fusion_set = disjoint_iter_set.second; - auto kir_set_it = disjoint_set_2_kir.find(fusion_set); - std::shared_ptr> kir_set; - if (kir_set_it == disjoint_set_2_kir.end()) { - kir_set = std::make_shared>(); - std::transform( - fusion_set->begin(), - fusion_set->end(), - std::inserter(*kir_set, kir_set->begin()), - [&gpu_lower](IterDomain* id) { - return gpu_lower->lowerValue(id)->as(); - }); - disjoint_set_2_kir.emplace(std::make_pair(fusion_set, kir_set)); - } else { - kir_set = kir_set_it->second; - } - kir_disjoint_iter_set_maps_.emplace(std::make_pair( - gpu_lower->lowerValue(disjoint_iter_set.first)->as(), - kir_set)); - } - - for (auto entry : concrete_id_map_) { - kir_concrete_id_map_.emplace(std::make_pair( - gpu_lower->lowerValue(entry.first)->as(), - gpu_lower->lowerValue(entry.second)->as())); - } - - for (const auto& entry : disjoint_iter_set_maps_) { - kir_2_fusion_[gpu_lower->lowerValue(entry.first)->as()] = - entry.first; - } - - // Make sure we have all IterDomains that could be used to generate a ForLoop - for (auto expr : fusion->exprs()) { - if (!expr->outputs()[0]->isA()) { - continue; - } - - auto tv_outputs = ir_utils::filterByType(expr->outputs()); - - for (auto out : tv_outputs) { - for (auto entry : out->domain()->domain()) { - kir_2_fusion_[gpu_lower->lowerValue(entry)->as()] = - entry; - } - } - } } bool ComputeAtMap::areMapped(IterDomain* id0, IterDomain* id1) const { @@ -568,20 +502,6 @@ bool ComputeAtMap::areMapped(IterDomain* id0, IterDomain* id1) const { return (set0_it->second.get() == set1_it->second.get()); } -bool ComputeAtMap::areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const { - assertLowered(has_lowered_kir_); - if (id0 == id1) { - return true; - } - auto set0_it = kir_disjoint_iter_set_maps_.find(id0); - auto set1_it = kir_disjoint_iter_set_maps_.find(id1); - if (set0_it == kir_disjoint_iter_set_maps_.end() || - set1_it == kir_disjoint_iter_set_maps_.end()) { - return false; - } - return (set0_it->second.get() == set1_it->second.get()); -} - IterDomain* ComputeAtMap::getConcreteMappedID(IterDomain* id) const { auto it = concrete_id_map_.find(id); if (it != concrete_id_map_.end()) { @@ -590,25 +510,6 @@ IterDomain* ComputeAtMap::getConcreteMappedID(IterDomain* id) const { return id; } -kir::IterDomain* ComputeAtMap::getConcreteMappedID(kir::IterDomain* id) const { - assertLowered(has_lowered_kir_); - auto it = kir_concrete_id_map_.find(id); - if (it != kir_concrete_id_map_.end()) { - return it->second; - } - return id; -} - -IterDomain* ComputeAtMap::toFusion(kir::IterDomain* kir) const { - assertLowered(has_lowered_kir_); - auto kir_2_fusion_it = kir_2_fusion_.find(kir); - TORCH_INTERNAL_ASSERT( - kir_2_fusion_it != kir_2_fusion_.end(), - "Kernel ir is not guarneteed to be reversible into fusion ir, could not find fusion entry. ", - kir::toString(kir, false)); - return kir_2_fusion_it->second; -} - std::string ComputeAtMap::toString() const { std::stringstream ss; diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index b2b70f8997d..8b7f9acd8fe 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -67,34 +67,18 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! same loop nest in the lowered code bool areMapped(IterDomain* id0, IterDomain* id1) const; - bool areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const; - //! Returns an iter domain that is the maximum expanded size of all iter //! domains the one provided maps to. Useful for opening loops to the correct //! iteration size. Not guarenteed to return the same ID every call, but is //! guarenteed to return iter domains in the same disjoint set. IterDomain* getConcreteMappedID(IterDomain* id) const; - kir::IterDomain* getConcreteMappedID(kir::IterDomain* id) const; - - // TODO: Would be great if we didn't need this, but we have nice functionality - // in iter_visitor that isn't moved over. Use of this is limited to indexing - // and this should definitely be removed by building out kernel ir to have - // better parity with fusion ir. - IterDomain* toFusion(kir::IterDomain* kir) const; - // Prints mapping information via Fusion IR std::string toString() const; private: - bool has_lowered_kir_ = false; - void mapIds(IterDomain* id0, IterDomain* id1); - //! Convert everything to lowered structures (kernel ir), as we will use - //! this class frequently during lowering. - void convertToKir(Fusion* fusion, GpuLower* gpu_lower); - private: MappingMode mapping_mode_ = MappingMode::LOOP; @@ -109,11 +93,6 @@ class TORCH_CUDA_CU_API ComputeAtMap { std::unordered_map>> disjoint_iter_set_maps_; - std::unordered_map< - kir::IterDomain*, - std::shared_ptr>> - kir_disjoint_iter_set_maps_; - // Keep a list of disjoint_iter_sets that's deterministic to iterate over std::deque>> disjoint_iter_sets_; @@ -125,12 +104,6 @@ class TORCH_CUDA_CU_API ComputeAtMap { // For each IterDomain set we will track how many concrete root domains were // used to generate the IterDomain std::unordered_map concrete_id_map_; - - std::unordered_map kir_concrete_id_map_; - - // Map kir::IterDomain* back to the fusion IR IterDomain*. - // TODO: Would be great if we didn't need this. - std::unordered_map kir_2_fusion_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index cea8b24e7ff..1702de93bdd 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -37,7 +37,7 @@ T* ptr(T* obj) { * } * * And therefore dispatch should never call: - * ptr(mutator)->handle(this->as()); + * ptr(mutator)->mutate(this->as()); */ template @@ -58,6 +58,10 @@ void Val::dispatch(T handler, Val* val) { break; } break; + case ValType::NamedScalar: + ptr(handler)->handle(val->as()); + return; + case ValType::IterDomain: ptr(handler)->handle(val->as()); return; @@ -67,8 +71,11 @@ void Val::dispatch(T handler, Val* val) { case ValType::TensorView: ptr(handler)->handle(val->as()); return; - case ValType::NamedScalar: - ptr(handler)->handle(val->as()); + case ValType::Predicate: + ptr(handler)->handle(val->as()); + return; + case ValType::TensorIndex: + ptr(handler)->handle(val->as()); return; default: break; @@ -79,12 +86,6 @@ void Val::dispatch(T handler, Val* val) { template void Expr::dispatch(T handler, Expr* expr) { switch (*(expr->getExprType())) { - case ExprType::Split: - ptr(handler)->handle(expr->as()); - return; - case ExprType::Merge: - ptr(handler)->handle(expr->as()); - return; case ExprType::UnaryOp: ptr(handler)->handle(expr->as()); return; @@ -103,6 +104,13 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Split: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Merge: + ptr(handler)->handle(expr->as()); + return; case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; @@ -115,6 +123,34 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::ViewOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Allocate: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Sync: + ptr(handler)->handle(expr->as()); + return; + case ExprType::InitMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::UpdateMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::ForLoop: + ptr(handler)->handle(expr->as()); + return; + case ExprType::IfThenElse: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridReduction: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridBroadcast: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridWelford: + ptr(handler)->handle(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -148,6 +184,10 @@ void Val::constDispatch(T handler, const Val* val) { break; } break; + case ValType::NamedScalar: + ptr(handler)->handle(val->as()); + return; + case ValType::IterDomain: ptr(handler)->handle(val->as()); return; @@ -157,8 +197,11 @@ void Val::constDispatch(T handler, const Val* val) { case ValType::TensorView: ptr(handler)->handle(val->as()); return; - case ValType::NamedScalar: - ptr(handler)->handle(val->as()); + case ValType::Predicate: + ptr(handler)->handle(val->as()); + return; + case ValType::TensorIndex: + ptr(handler)->handle(val->as()); return; default: break; @@ -169,12 +212,6 @@ void Val::constDispatch(T handler, const Val* val) { template void Expr::constDispatch(T handler, const Expr* expr) { switch (*(expr->getExprType())) { - case ExprType::Split: - ptr(handler)->handle(expr->as()); - return; - case ExprType::Merge: - ptr(handler)->handle(expr->as()); - return; case ExprType::UnaryOp: ptr(handler)->handle(expr->as()); return; @@ -193,6 +230,13 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::BroadcastOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Split: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Merge: + ptr(handler)->handle(expr->as()); + return; case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; @@ -205,6 +249,34 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::ViewOp: ptr(handler)->handle(expr->as()); return; + + case ExprType::Allocate: + ptr(handler)->handle(expr->as()); + return; + case ExprType::Sync: + ptr(handler)->handle(expr->as()); + return; + case ExprType::InitMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::UpdateMagicZero: + ptr(handler)->handle(expr->as()); + return; + case ExprType::ForLoop: + ptr(handler)->handle(expr->as()); + return; + case ExprType::IfThenElse: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridReduction: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridBroadcast: + ptr(handler)->handle(expr->as()); + return; + case ExprType::GridWelford: + ptr(handler)->handle(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -232,28 +304,42 @@ void Statement::constDispatch(T handler, const Statement* stmt) { * ptr(mutator)->mutate(this->as()); */ template -Statement* Val::mutatorDispatch(T mutator, Val* val) { +void Val::mutatorDispatch(T mutator, Val* val) { switch (*(val->getValType())) { case ValType::Scalar: switch (*(val->getDataType())) { case DataType::Bool: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case DataType::Double: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case DataType::Int: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; default: break; } break; + case ValType::NamedScalar: + ptr(mutator)->mutate(val->as()); + return; + case ValType::IterDomain: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case ValType::TensorDomain: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; case ValType::TensorView: - return ptr(mutator)->mutate(val->as()); - case ValType::NamedScalar: - return ptr(mutator)->mutate(val->as()); + ptr(mutator)->mutate(val->as()); + return; + case ValType::Predicate: + ptr(mutator)->mutate(val->as()); + return; + case ValType::TensorIndex: + ptr(mutator)->mutate(val->as()); + return; default: break; } @@ -261,44 +347,87 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) { } template -Statement* Expr::mutatorDispatch(T mutator, Expr* expr) { +void Expr::mutatorDispatch(T mutator, Expr* expr) { switch (*(expr->getExprType())) { - case ExprType::Split: - return ptr(mutator)->mutate(expr->as()); - case ExprType::Merge: - return ptr(mutator)->mutate(expr->as()); case ExprType::UnaryOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::BinaryOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::TernaryOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ReductionOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::WelfordOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::BroadcastOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; + + case ExprType::Split: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::Merge: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::TransposeOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ShiftOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::GatherOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ViewOp: - return ptr(mutator)->mutate(expr->as()); + ptr(mutator)->mutate(expr->as()); + return; + + case ExprType::Allocate: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::Sync: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::InitMagicZero: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::UpdateMagicZero: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::ForLoop: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::IfThenElse: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::GridReduction: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::GridBroadcast: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::GridWelford: + ptr(mutator)->mutate(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } } template -Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) { +void Statement::mutatorDispatch(T mutator, Statement* stmt) { if (stmt->isVal()) { - return ptr(mutator)->mutate(stmt->as()); + ptr(mutator)->mutate(stmt->as()); + return; } if (stmt->isExpr()) { - return ptr(mutator)->mutate(stmt->as()); + ptr(mutator)->mutate(stmt->as()); + return; } TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!"); } @@ -308,11 +437,11 @@ Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) { * classes. Actual visitors/mutators should inhereit from these classes and call * ->dispatch(this) to avoid needing an explicit instantiation. */ -template void Statement::dispatch(OptOutDispatch, Statement*); +template void Statement::dispatch(OptOutDispatch&, Statement*); template void Statement::dispatch(OptOutDispatch*, Statement*); -template void Val::dispatch(OptOutDispatch, Val*); +template void Val::dispatch(OptOutDispatch&, Val*); template void Val::dispatch(OptOutDispatch*, Val*); -template void Expr::dispatch(OptOutDispatch, Expr*); +template void Expr::dispatch(OptOutDispatch&, Expr*); template void Expr::dispatch(OptOutDispatch*, Expr*); template void Statement::dispatch(OptInDispatch, Statement*); @@ -322,33 +451,26 @@ template void Val::dispatch(OptInDispatch*, Val*); template void Expr::dispatch(OptInDispatch, Expr*); template void Expr::dispatch(OptInDispatch*, Expr*); -template void Statement::constDispatch(OptOutConstDispatch, const Statement*); +template void Statement::constDispatch(OptOutConstDispatch&, const Statement*); template void Statement::constDispatch(OptOutConstDispatch*, const Statement*); -template void Val::constDispatch(OptOutConstDispatch, const Val*); +template void Val::constDispatch(OptOutConstDispatch&, const Val*); template void Val::constDispatch(OptOutConstDispatch*, const Val*); -template void Expr::constDispatch(OptOutConstDispatch, const Expr*); +template void Expr::constDispatch(OptOutConstDispatch&, const Expr*); template void Expr::constDispatch(OptOutConstDispatch*, const Expr*); -template void Statement::constDispatch(OptInConstDispatch, const Statement*); +template void Statement::constDispatch(OptInConstDispatch&, const Statement*); template void Statement::constDispatch(OptInConstDispatch*, const Statement*); -template void Val::constDispatch(OptInConstDispatch, const Val*); +template void Val::constDispatch(OptInConstDispatch&, const Val*); template void Val::constDispatch(OptInConstDispatch*, const Val*); -template void Expr::constDispatch(OptInConstDispatch, const Expr*); +template void Expr::constDispatch(OptInConstDispatch&, const Expr*); template void Expr::constDispatch(OptInConstDispatch*, const Expr*); -template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*); -template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*); -template Statement* Val::mutatorDispatch(OptOutMutator, Val*); -template Statement* Val::mutatorDispatch(OptOutMutator*, Val*); -template Statement* Expr::mutatorDispatch(OptOutMutator, Expr*); -template Statement* Expr::mutatorDispatch(OptOutMutator*, Expr*); - -template Statement* Statement::mutatorDispatch(OptInMutator, Statement*); -template Statement* Statement::mutatorDispatch(OptInMutator*, Statement*); -template Statement* Val::mutatorDispatch(OptInMutator, Val*); -template Statement* Val::mutatorDispatch(OptInMutator*, Val*); -template Statement* Expr::mutatorDispatch(OptInMutator, Expr*); -template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*); +template void Statement::mutatorDispatch(OptOutMutator&, Statement*); +template void Statement::mutatorDispatch(OptOutMutator*, Statement*); +template void Val::mutatorDispatch(OptOutMutator&, Val*); +template void Val::mutatorDispatch(OptOutMutator*, Val*); +template void Expr::mutatorDispatch(OptOutMutator&, Expr*); +template void Expr::mutatorDispatch(OptOutMutator*, Expr*); void OptOutDispatch::handle(Statement* s) { Statement::dispatch(this, s); @@ -362,18 +484,6 @@ void OptOutDispatch::handle(Val* v) { Val::dispatch(this, v); } -void OptInDispatch::handle(Statement* s) { - Statement::dispatch(this, s); -} - -void OptInDispatch::handle(Expr* e) { - Expr::dispatch(this, e); -} - -void OptInDispatch::handle(Val* v) { - Val::dispatch(this, v); -} - void OptOutConstDispatch::handle(const Statement* s) { Statement::constDispatch(this, s); } @@ -386,46 +496,224 @@ void OptOutConstDispatch::handle(const Val* v) { Val::constDispatch(this, v); } -void OptInConstDispatch::handle(const Statement* s) { - Statement::constDispatch(this, s); +void OptInConstDispatch::unhandled(const Statement* stmt) { + if (stmt->isExpr()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getExprType().value(), "."); + } else if (stmt->isVal()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getValType().value(), "."); + } else { + TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type."); + } } -void OptInConstDispatch::handle(const Expr* e) { - Expr::constDispatch(this, e); +void OptInDispatch::unhandled(Statement* stmt) { + if (stmt->isExpr()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getExprType().value(), "."); + } else if (stmt->isVal()) { + TORCH_INTERNAL_ASSERT( + false, "Handle not overriden for ", stmt->getValType().value(), "."); + } else { + TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type."); + } } -void OptInConstDispatch::handle(const Val* v) { - Val::constDispatch(this, v); +// Vals +void OptOutConstDispatch::handle(const Bool* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const Double* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const Int* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const NamedScalar* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const IterDomain* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TensorDomain* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TensorView* stmt) { + unhandled(stmt); +} + +void OptOutConstDispatch::handle(const kir::Predicate* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) { + unhandled(stmt); +} + +// Exprs +void OptOutConstDispatch::handle(const UnaryOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const BinaryOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TernaryOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const ReductionOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const WelfordOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const BroadcastOp* stmt) { + unhandled(stmt); } -Statement* OptInMutator::mutate(Statement* s) { - return Statement::mutatorDispatch(this, s); +void OptOutConstDispatch::handle(const Split* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const Merge* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const TransposeOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const ShiftOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const GatherOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const ViewOp* stmt) { + unhandled(stmt); } -Statement* OptInMutator::mutate(Expr* e) { - return Expr::mutatorDispatch(this, e); +void OptOutConstDispatch::handle(const kir::Allocate* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::Sync* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::UpdateMagicZero* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::ForLoop* stmt) { + unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::IfThenElse* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::GridReduction* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::GridWelford* stmt) { + unhandled(stmt); +} + +void OptOutDispatch::unhandled(Statement*) {} -Statement* OptInMutator::mutate(Val* v) { - // If value is already mutated, return the mutation - if (mutations.find(v) != mutations.end()) - return mutations[v]; - return Val::mutatorDispatch(this, v); +// Vals +void OptOutDispatch::handle(Bool* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(Double* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(Int* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(NamedScalar* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(IterDomain* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TensorDomain* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TensorView* stmt) { + unhandled(stmt); } -Statement* OptOutMutator::mutate(Statement* s) { - return Statement::mutatorDispatch(this, s); +void OptOutDispatch::handle(kir::Predicate* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::TensorIndex* stmt) { + unhandled(stmt); } -Statement* OptOutMutator::mutate(Expr* e) { - return Expr::mutatorDispatch(this, e); +// Exprs +void OptOutDispatch::handle(UnaryOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(BinaryOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TernaryOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(ReductionOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(WelfordOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(BroadcastOp* stmt) { + unhandled(stmt); } -Statement* OptOutMutator::mutate(Val* v) { - // If value is already mutated, return the mutation - if (mutations.find(v) != mutations.end()) - return mutations[v]; - return Val::mutatorDispatch(this, v); +void OptOutDispatch::handle(Split* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(Merge* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(TransposeOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(ShiftOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(GatherOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(ViewOp* stmt) { + unhandled(stmt); +} + +void OptOutDispatch::handle(kir::Allocate* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::Sync* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::InitMagicZero* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::UpdateMagicZero* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::ForLoop* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::IfThenElse* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::GridReduction* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::GridBroadcast* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::GridWelford* stmt) { + unhandled(stmt); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index c1be76eb950..6961ebd6a15 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -1,9 +1,9 @@ #pragma once -#include - +#include #include -#include + +#include #include @@ -48,7 +48,7 @@ namespace torch { namespace jit { namespace fuser { namespace cuda { - +class IrContainer; class Fusion; // Hierarchal dispatch functions for handle @@ -60,14 +60,13 @@ class Val; class IterDomain; class TensorDomain; class TensorView; + class Bool; class Double; class Int; class NamedScalar; // Exprs -class Split; -class Merge; class UnaryOp; class BinaryOp; class TernaryOp; @@ -79,9 +78,35 @@ class ShiftOp; class GatherOp; class ViewOp; +// Exprs +class Split; +class Merge; +class TransposeOp; +class ShiftOp; +class GatherOp; +class ViewOp; + +namespace kir { +class Predicate; +class TensorIndex; + +class Allocate; +class Sync; +class ForLoop; +class IfThenElse; +class GridReduction; +class GridBroadcast; +class GridWelford; +class InitMagicZero; +class UpdateMagicZero; +} // namespace kir + // By default, all IR nodes are handled in this dispatch, and will call an empty // function on all nodes. class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { + protected: + virtual void unhandled(const Statement*) {} + public: // Hierarchal dispatch functions for handle virtual void handle(const Statement*); @@ -89,30 +114,47 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const Val*); // Vals - virtual void handle(const IterDomain*) {} - virtual void handle(const TensorDomain*) {} - virtual void handle(const TensorView*) {} - virtual void handle(const Bool*) {} - virtual void handle(const Double*) {} - virtual void handle(const Int*) {} - virtual void handle(const NamedScalar*) {} + virtual void handle(const IterDomain* stmt); + virtual void handle(const TensorDomain* stmt); + virtual void handle(const TensorView* stmt); + virtual void handle(const Bool* stmt); + virtual void handle(const Double* stmt); + virtual void handle(const Int* stmt); + virtual void handle(const NamedScalar* stmt); + + virtual void handle(const kir::Predicate*); + virtual void handle(const kir::TensorIndex*); // Exprs - virtual void handle(const Split*) {} - virtual void handle(const Merge*) {} - virtual void handle(const UnaryOp*) {} - virtual void handle(const BinaryOp*) {} - virtual void handle(const TernaryOp*) {} - virtual void handle(const ReductionOp*) {} - virtual void handle(const WelfordOp*) {} - virtual void handle(const BroadcastOp*) {} - virtual void handle(const TransposeOp*) {} - virtual void handle(const ShiftOp*) {} - virtual void handle(const GatherOp*) {} - virtual void handle(const ViewOp*) {} + virtual void handle(const UnaryOp* stmt); + virtual void handle(const BinaryOp* stmt); + virtual void handle(const TernaryOp* stmt); + virtual void handle(const ReductionOp* stmt); + virtual void handle(const WelfordOp* stmt); + virtual void handle(const BroadcastOp* stmt); + + virtual void handle(const Split* stmt); + virtual void handle(const Merge* stmt); + virtual void handle(const TransposeOp* stmt); + virtual void handle(const ShiftOp* stmt); + virtual void handle(const GatherOp* stmt); + virtual void handle(const ViewOp* stmt); + + virtual void handle(const kir::Allocate*); + virtual void handle(const kir::Sync*); + virtual void handle(const kir::InitMagicZero*); + virtual void handle(const kir::UpdateMagicZero*); + virtual void handle(const kir::ForLoop*); + virtual void handle(const kir::IfThenElse*); + virtual void handle(const kir::GridReduction*); + virtual void handle(const kir::GridBroadcast*); + virtual void handle(const kir::GridWelford*); }; class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { + protected: + virtual void unhandled(Statement*); + public: // Hierarchal dispatch functions for handle virtual void handle(Statement*); @@ -120,190 +162,88 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(Val*); // Vals - virtual void handle(IterDomain*) {} - virtual void handle(TensorDomain*) {} - virtual void handle(TensorView*) {} - virtual void handle(Bool*) {} - virtual void handle(Double*) {} - virtual void handle(Int*) {} - virtual void handle(NamedScalar*) {} + virtual void handle(Bool* stmt); + virtual void handle(Double* stmt); + virtual void handle(Int* stmt); + virtual void handle(NamedScalar* stmt); + virtual void handle(IterDomain* stmt); + virtual void handle(TensorDomain* stmt); + virtual void handle(TensorView* stmt); + + virtual void handle(kir::Predicate*); + virtual void handle(kir::TensorIndex*); // Exprs - virtual void handle(Split*) {} - virtual void handle(Merge*) {} - virtual void handle(UnaryOp*) {} - virtual void handle(BinaryOp*) {} - virtual void handle(TernaryOp*) {} - virtual void handle(ReductionOp*) {} - virtual void handle(WelfordOp*) {} - virtual void handle(BroadcastOp*) {} - virtual void handle(TransposeOp*) {} - virtual void handle(ShiftOp*) {} - virtual void handle(GatherOp*) {} - virtual void handle(ViewOp*) {} + virtual void handle(UnaryOp* stmt); + virtual void handle(BinaryOp* stmt); + virtual void handle(TernaryOp* stmt); + virtual void handle(ReductionOp* stmt); + virtual void handle(WelfordOp* stmt); + virtual void handle(BroadcastOp* stmt); + + virtual void handle(Split* stmt); + virtual void handle(Merge* stmt); + virtual void handle(TransposeOp* stmt); + virtual void handle(ShiftOp* stmt); + virtual void handle(GatherOp* stmt); + virtual void handle(ViewOp* stmt); + + virtual void handle(kir::Allocate* stmt); + virtual void handle(kir::Sync* stmt); + virtual void handle(kir::InitMagicZero* stmt); + virtual void handle(kir::UpdateMagicZero* stmt); + virtual void handle(kir::ForLoop* stmt); + virtual void handle(kir::IfThenElse* stmt); + virtual void handle(kir::GridReduction* stmt); + virtual void handle(kir::GridBroadcast* stmt); + virtual void handle(kir::GridWelford* stmt); }; -class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase { +class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch { public: - // Hierarchal dispatch functions for handle - virtual void handle(const Statement*); - virtual void handle(const Expr*); - virtual void handle(const Val*); - - // Vals - virtual void handle(const IterDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain."); - } - virtual void handle(const TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain."); - } - virtual void handle(const TensorView*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView."); - } - virtual void handle(const Bool*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool."); - } - virtual void handle(const Double*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double."); - } - virtual void handle(const Int*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int."); - } - virtual void handle(const NamedScalar*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar."); - } + using OptOutConstDispatch::handle; - // Exprs - virtual void handle(const Split*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split."); - } - virtual void handle(const Merge*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); - } - virtual void handle(const UnaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); - } - virtual void handle(const BinaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp."); - } - virtual void handle(const WelfordOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp."); - } - virtual void handle(const TernaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp."); - } - virtual void handle(const ReductionOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); - } - virtual void handle(const BroadcastOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); - } - virtual void handle(const TransposeOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp."); - } - virtual void handle(const ShiftOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp."); - } - virtual void handle(const GatherOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp."); - } - virtual void handle(const ViewOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ViewOp."); - } + protected: + virtual void unhandled(const Statement* stmt) final; }; -class TORCH_CUDA_CU_API OptInDispatch : public PolymorphicBase { +class TORCH_CUDA_CU_API OptInDispatch : public OptOutDispatch { public: - // Hierarchal dispatch functions for handle - virtual void handle(Statement* s); - virtual void handle(Expr* e); - virtual void handle(Val* v); - - // Vals - virtual void handle(IterDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain."); - } - virtual void handle(TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain."); - } - virtual void handle(TensorView*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView."); - } - virtual void handle(Bool*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool."); - } - virtual void handle(Double*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double."); - } - virtual void handle(Int*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int."); - } - virtual void handle(NamedScalar*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar."); - } + using OptOutDispatch::handle; - // Exprs - virtual void handle(Split*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split."); - } - virtual void handle(Merge*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); - } - virtual void handle(UnaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); - } - virtual void handle(BinaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp."); - } - virtual void handle(TernaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp."); - } - virtual void handle(ReductionOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); - } - virtual void handle(WelfordOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp."); - } - virtual void handle(BroadcastOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); - } - virtual void handle(TransposeOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp."); - } - virtual void handle(ShiftOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp."); - } - virtual void handle(GatherOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp."); - } - virtual void handle(ViewOp*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ViewOp."); - } + protected: + virtual void unhandled(Statement* stmt) final; }; +// Class to perform mutations on Fusion IR. Exprs can simply be redefined, but +// when mutating values they have to be registered through registerMutation so +// that exprs can detect there's been a muatation and know to modify all +// instances of that Val. This means each Val should be mutated "consistently". +// Otherwise behavior may be difficult to understand as it depends on which +// order mutate is called in. This class expects user to topologically call the +// statments of interest so inputs are called and mutated before exprs depending +// on them. +// +// Warning: TensorViews need to be treated carefully. As we don't generally +// register their mutation when their tensor domains only change. If a TV needs +// to be swapped out, it needs to be registered as a "proper" mutation like +// other vals, on top of TensorDomain being updated in the mutated TensorView. +// // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { public: // Hierarchal dispatch functions for handle - virtual Statement* mutate(Statement* s); - virtual Statement* mutate(Expr* e); - virtual Statement* mutate(Val* v); - - // We always want to dispatch through a Val, so we can capture and dispatch - // correctly members of nodes like Split->TensorDomain If we don't call the - // below function or manually cast to use mutate(Val* v) we can't intercept - // and mutate by capturing mutate(Val* v), which is what we do when we want to - // replace all instances of a value. - Statement* mutateAsVal(Val* v) { - return mutate(v); - } + virtual void mutate(Statement* s); + virtual void mutate(Expr* e); + virtual void mutate(Val* v); + + void registerMutation(Val* val, Val* mutation); - void registerMutation(Val* val, Val* mutation) { - TORCH_INTERNAL_ASSERT( - mutations.find(val) == mutations.end(), - " The same value is incorrectly being mutated twice.", - " One mutation per mutation pass is allowed."); - mutations[val] = mutation; + Val* maybeMutated(Val* val) { + if (mutations.find(val) == mutations.end()) { + return val; + } + return mutations.at(val); } std::unordered_map mutations; @@ -311,105 +251,44 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { //****Functions below defined in mutator.cpp***** // Vals - virtual Statement* mutate(IterDomain*); - virtual Statement* mutate(TensorDomain*); - virtual Statement* mutate(TensorView*); - virtual Statement* mutate(Bool*); - virtual Statement* mutate(Double*); - virtual Statement* mutate(Int*); - virtual Statement* mutate(NamedScalar*); - - // Exprs - virtual Statement* mutate(Split*); - virtual Statement* mutate(Merge*); - virtual Statement* mutate(UnaryOp*); - virtual Statement* mutate(BinaryOp*); - virtual Statement* mutate(TernaryOp*); - virtual Statement* mutate(ReductionOp*); - virtual Statement* mutate(WelfordOp*); - virtual Statement* mutate(BroadcastOp*); - virtual Statement* mutate(TransposeOp*); - virtual Statement* mutate(ShiftOp*); - virtual Statement* mutate(GatherOp*); - virtual Statement* mutate(ViewOp*); -}; - -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase { - public: - std::unordered_map mutations; - - public: - void registerMutation(Val* val, Val* mutation) { - TORCH_INTERNAL_ASSERT( - mutations.find(val) == mutations.end(), - " The same value is incorrectly being mutated twice.", - " One mutation per mutation pass is allowed."); - mutations[val] = mutation; - } - - // Hierarchal dispatch functions for mutate - virtual Statement* mutate(Statement*); - virtual Statement* mutate(Expr*); - virtual Statement* mutate(Val*); + virtual void mutate(Bool*); + virtual void mutate(Double*); + virtual void mutate(Int*); + virtual void mutate(NamedScalar*); + virtual void mutate(IterDomain*); + virtual void mutate(TensorDomain*); + virtual void mutate(TensorView*); - // Vals - virtual Statement* mutate(IterDomain*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for IterDomain."); - } - virtual Statement* mutate(TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorDomain."); - } - virtual Statement* mutate(TensorView*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorView."); - } - virtual Statement* mutate(Bool*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Bool."); - } - virtual Statement* mutate(Int*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Int."); - } - virtual Statement* mutate(NamedScalar*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for NamedScalar."); - } + virtual void mutate(kir::Predicate*); + virtual void mutate(kir::TensorIndex*); // Exprs - virtual Statement* mutate(Split*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Split."); - } - virtual Statement* mutate(Merge*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Merge."); - } - virtual Statement* mutate(UnaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for UnaryOp."); - } - virtual Statement* mutate(BinaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BinaryOp."); - } - virtual Statement* mutate(TernaryOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TernaryOp."); - } - virtual Statement* mutate(ReductionOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ReductionOp."); - } - virtual Statement* mutate(WelfordOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for WelfordOp."); - } - virtual Statement* mutate(BroadcastOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BroadcastOp."); - } - virtual Statement* mutate(TransposeOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TransposeOp."); - } - virtual Statement* mutate(ShiftOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ShiftOp."); - } - virtual Statement* mutate(GatherOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GatherOp."); - } - virtual Statement* mutate(ViewOp*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ViewOp."); - } + virtual void mutate(UnaryOp*); + virtual void mutate(BinaryOp*); + virtual void mutate(TernaryOp*); + virtual void mutate(ReductionOp*); + virtual void mutate(WelfordOp*); + virtual void mutate(BroadcastOp*); + + virtual void mutate(Split*); + virtual void mutate(Merge*); + virtual void mutate(TransposeOp*); + virtual void mutate(ShiftOp*); + virtual void mutate(GatherOp*); + virtual void mutate(ViewOp*); + + virtual void mutate(kir::Allocate*); + virtual void mutate(kir::Sync*); + virtual void mutate(kir::InitMagicZero*); + virtual void mutate(kir::UpdateMagicZero*); + virtual void mutate(kir::ForLoop*); + virtual void mutate(kir::IfThenElse*); + virtual void mutate(kir::GridReduction*); + virtual void mutate(kir::GridBroadcast*); + virtual void mutate(kir::GridWelford*); + + protected: + void removeExpr(IrContainer*, Expr*); }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp index 288dbb198b0..09481319569 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp @@ -1,9 +1,11 @@ -#include #include #include +#include #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -68,8 +70,8 @@ std::vector makeSortedEvaluationList(std::vector input) { //! Kernel IR utility, collects all the symbolic integers //! used in allocation nodes. void collectBufferSizes( - std::vector& into, - const std::vector& exprs) { + std::vector& into, + const std::vector& exprs) { for (auto expr : exprs) { if (auto allocate = dynamic_cast(expr)) { into.push_back(allocate->size()); @@ -82,56 +84,44 @@ void collectBufferSizes( } } -//! Kernel IR utility, collects all the kir symbolic +//! Kernel IR utility, collects all the kernel symbolic //! integers we will need at runtime, i.e. after the //! generated cuda kernel has already been compiled. //! The values are to be used for runtime logic, like //! `computeLaunchparams`. -std::vector collectRuntimeUsedIntegers( - Fusion* fusion, - GpuLower* lower) { - std::vector ret; - +std::vector collectRuntimeUsedIntegers(kir::Kernel* kernel) { + std::vector ret; + auto all_tvs = ir_utils::allTvs(kernel); // Collect extent and integer inputs - for (auto val : fusion->usedMathVals()) { - auto kir_val = lower->lowerValue(val); - if (auto kir_tv = dynamic_cast(kir_val)) { - for (auto id : kir_tv->domain()->domain()) { - ret.push_back(id->extent()); - } - } else if (val->isFusionInput()) { - if (kir_val->isA()) { - ret.push_back(kir_val); - } + for (auto tv : all_tvs) { + for (auto id : tv->domain()->domain()) { + ret.push_back(id->extent()); + } + } + for (auto inp : kernel->inputs()) { + if (inp->isA()) { + ret.push_back(inp); } } - // Collect allocation sizes: - collectBufferSizes(ret, lower->kernel()->topLevelExprs()); - + collectBufferSizes(ret, kernel->topLevelExprs()); return makeSortedEvaluationList(ret); } -//! Fusion IR utility, collects all the fusionIR symbolic -//! integers we will need at runtime, i.e. after the -//! generated cuda kernel has already been compiled. -//! The values are to be used for runtime logic, like -//! `canSchedule` in heuristic look up. + std::vector collectRuntimeUsedIntegers(Fusion* fusion) { std::vector ret; - + auto all_tvs = ir_utils::allTvs(fusion); // Collect extent and integer inputs - for (auto val : fusion->usedMathVals()) { - if (auto tv = dynamic_cast(val)) { - for (auto id : tv->domain()->domain()) { - ret.push_back(id->extent()); - } - } else if (val->isFusionInput()) { - if (val->isA()) { - ret.push_back(val); - } + for (auto tv : all_tvs) { + for (auto id : tv->domain()->domain()) { + ret.push_back(id->extent()); + } + } + for (auto inp : fusion->inputs()) { + if (inp->isA()) { + ret.push_back(inp); } } - return makeSortedEvaluationList(ret); } @@ -140,7 +130,7 @@ std::vector collectRuntimeUsedIntegers(Fusion* fusion) { template void PrecomputedIntegersBase::initializeValueList( typename IRContext::EVALUATOR_TYPE& const_evaluator, - const std::vector& sorted_value_list) { + const std::vector& sorted_value_list) { // Initialize workspace num_of_values_ = sorted_value_list.size(); defined_ = std::vector(num_of_values_, false); @@ -161,7 +151,7 @@ void PrecomputedIntegersBase::initializeValueList( template c10::optional PrecomputedIntegersBase::getMaybeValueFor( - const IR_VAL* val) { + const Val* val) { auto index = val->evaluatorIndex(); if (index < 0) { return c10::nullopt; @@ -172,6 +162,17 @@ c10::optional PrecomputedIntegersBase::getMaybeValueFor( return values_[index]; } +template +void PrecomputedIntegersBase::print() const { + std::cout << "Precomputed Integers:\n"; + for (auto i : c10::irange(symbols_.size())) { + if (defined_[i]) { + std::cout << symbols_[i]->toInlineString() << " = " << values_[i] + << std::endl; + } + } +} + template void PrecomputedIntegersBase::evaluate() { FUSER_PERF_SCOPE("PrecomputedIntegers::Evaluate"); @@ -208,10 +209,9 @@ NaiveIntegerMachine::NaiveIntegerMachine( for (auto val : precomputed_integers_.symbols_) { auto def = val->definition(); if (def) { - if (auto uop = dynamic_cast(def)) { + if (auto uop = dynamic_cast(def)) { makeUnaryOp(uop); - } else if ( - auto bop = dynamic_cast(def)) { + } else if (auto bop = dynamic_cast(def)) { makeBinaryOp(bop); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr"); @@ -234,8 +234,7 @@ void NaiveIntegerMachine::run() { } template -void NaiveIntegerMachine::makeUnaryOp( - typename IRContext::UNARY_OP_TYPE* uop) { +void NaiveIntegerMachine::makeUnaryOp(UnaryOp* uop) { int in = uop->inputs()[0]->evaluatorIndex(); int out = uop->outputs()[0]->evaluatorIndex(); TORCH_INTERNAL_ASSERT(in >= 0, "Integer Machine: unknown input: ", uop); @@ -249,8 +248,7 @@ void NaiveIntegerMachine::makeUnaryOp( } template -void NaiveIntegerMachine::makeBinaryOp( - typename IRContext::BINARY_OP_TYPE* bop) { +void NaiveIntegerMachine::makeBinaryOp(BinaryOp* bop) { int in0 = bop->inputs()[0]->evaluatorIndex(); int in1 = bop->inputs()[1]->evaluatorIndex(); int out = bop->outputs()[0]->evaluatorIndex(); @@ -377,11 +375,8 @@ void NaiveIntegerMachine::runBinaryOp(int index) { precomputed_integers_.defined_[dest_index] = true; } -KernelPrecomputedIntegers::KernelPrecomputedIntegers( - Fusion* fusion, - GpuLower& lower) - : lower_(&lower) { - loadSymbols(collectRuntimeUsedIntegers(fusion, lower_)); +KernelPrecomputedIntegers::KernelPrecomputedIntegers(kir::Kernel* kernel) { + loadSymbols(collectRuntimeUsedIntegers(kernel)); kir::ExpressionEvaluator evaluator; initializeValueList(evaluator, symbols()); initializeNamedScalars(); @@ -389,11 +384,11 @@ KernelPrecomputedIntegers::KernelPrecomputedIntegers( } void KernelPrecomputedIntegers::bindTensorMetaData( - kir::TensorView* tv, + TensorView* tv, const at::Tensor& at_tensor) { - std::vector> ret; + std::vector> ret; const auto root_domain = - kir::TensorDomain::noReductions(tv->domain()->rootDomain()); + TensorDomain::noReductions(tv->domain()->getRootDomain()); TORCH_INTERNAL_ASSERT( at_tensor.ndimension() == static_cast(root_domain.size()), "Something went wrong configuring launch. Inputs do not match."); @@ -411,7 +406,7 @@ namespace { //! and returns the corresponding parallel type if a match //! is found. c10::optional getMaybeThreadSizeParallelType( - kir::NamedScalar* named_scalar) { + NamedScalar* named_scalar) { auto& var_name = named_scalar->name(); for (auto ptype : kParallelTypeThreads) { if (var_name == stringifyThreadSize(ptype)) { @@ -425,7 +420,7 @@ c10::optional getMaybeThreadSizeParallelType( void KernelPrecomputedIntegers::initializeNamedScalars() { for (auto val : symbols()) { - if (auto named_scalar = dynamic_cast(val)) { + if (auto named_scalar = dynamic_cast(val)) { auto maybe_parallel_type = getMaybeThreadSizeParallelType(named_scalar); if (maybe_parallel_type.has_value()) { auto& index_list = @@ -440,17 +435,17 @@ void KernelPrecomputedIntegers::initializeNamedScalars() { } void KernelPrecomputedIntegers::bindKernelInputs( + kir::Kernel* kernel, const at::ArrayRef& aten_inputs) { if (hasValidValues()) { invalidate(); } - auto kernel = lower_->kernel(); const auto& inputs = kernel->inputs(); for (const auto i : c10::irange(inputs.size())) { const auto input = inputs[i]; - if (auto tensor_input = dynamic_cast(input)) { + if (auto tensor_input = dynamic_cast(input)) { const auto aten_tensor = aten_inputs[i].toTensor(); bindTensorMetaData(tensor_input, aten_tensor); } else if (input->isScalar() && input->dtype() == DataType::Int) { diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.h b/torch/csrc/jit/codegen/cuda/evaluator_common.h index 0c16e2a8b04..7cbe37c602b 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.h +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.h @@ -35,18 +35,14 @@ class ExpressionEvaluator; //! Context for using generic logic on FusionIR class FusionIRContext { public: - using VAL_TYPE = Val; - using EXPR_TYPE = Expr; using TV_TYPE = TensorView; using EVALUATOR_TYPE = ExpressionEvaluator; - using BINARY_OP_TYPE = BinaryOp; - using UNARY_OP_TYPE = UnaryOp; - static BinaryOpType getOpType(BINARY_OP_TYPE* bop) { + static BinaryOpType getOpType(BinaryOp* bop) { return bop->getBinaryOpType(); } - static UnaryOpType getOpType(UNARY_OP_TYPE* uop) { + static UnaryOpType getOpType(UnaryOp* uop) { return uop->getUnaryOpType(); } }; @@ -54,19 +50,14 @@ class FusionIRContext { //! Context for using generic logic on KernelIR class KernelIRContext { public: - using VAL_TYPE = kir::Val; - using EXPR_TYPE = kir::Expr; - using TV_TYPE = kir::TensorView; using EVALUATOR_TYPE = kir::ExpressionEvaluator; - using BINARY_OP_TYPE = kir::BinaryOp; - using UNARY_OP_TYPE = kir::UnaryOp; - static BinaryOpType getOpType(BINARY_OP_TYPE* bop) { - return bop->operation(); + static BinaryOpType getOpType(BinaryOp* bop) { + return bop->getBinaryOpType(); } - static UnaryOpType getOpType(UNARY_OP_TYPE* uop) { - return uop->operation(); + static UnaryOpType getOpType(UnaryOp* uop) { + return uop->getUnaryOpType(); } }; @@ -97,10 +88,10 @@ class NaiveIntegerMachine { private: //! Convert an unary IR expr to an instruction - void makeUnaryOp(typename IRContext::UNARY_OP_TYPE* uop); + void makeUnaryOp(UnaryOp* uop); //! Convert an binary IR expr to an instruction - void makeBinaryOp(typename IRContext::BINARY_OP_TYPE* bop); + void makeBinaryOp(BinaryOp* bop); //! Create an empty instruction with all default values //! and place it at the end of the instruction buffer. @@ -169,11 +160,6 @@ class NaiveIntegerMachine { //! integers and store them in the workspace ahead of time. template class PrecomputedIntegersBase { - using IR_UNARY_OP = typename IRContext::UNARY_OP_TYPE; - using IR_BINARY_OP = typename IRContext::BINARY_OP_TYPE; - using IR_VAL = typename IRContext::VAL_TYPE; - using IR_EXPR = typename IRContext::EXPR_TYPE; - using IR_TV = typename IRContext::TV_TYPE; using INTEGER_MACHINE = NaiveIntegerMachine; public: @@ -190,7 +176,10 @@ class PrecomputedIntegersBase { //! Returns value for the given IR node if it's stored //! in the workspace and has been evaluated. - c10::optional getMaybeValueFor(const IR_VAL* val); + c10::optional getMaybeValueFor(const Val* val); + + //! Debugging helper, prints all the currently known values + void print() const; protected: //! Initialize the workspace before first use. @@ -198,7 +187,7 @@ class PrecomputedIntegersBase { //! been topologically sorted. void initializeValueList( typename IRContext::EVALUATOR_TYPE& evaluator, - const std::vector& sorted_value_list); + const std::vector& sorted_value_list); //! Bind concrete value to the given index //! if the index is valid. @@ -215,12 +204,12 @@ class PrecomputedIntegersBase { void invalidate(); //! Interface for subclasses to access symbols_ - void loadSymbols(std::vector symbols) { + void loadSymbols(std::vector symbols) { symbols_ = std::move(symbols); } //! Interface for subclasses to access symbols_ - std::vector& symbols() { + std::vector& symbols() { return symbols_; } @@ -267,7 +256,7 @@ class PrecomputedIntegersBase { std::vector values_; //! Stores the IR nodes corresponding to each index. - std::vector symbols_; + std::vector symbols_; //! An internal log to keep track of all the bindings //! used in each evaluation cycle. To be used for @@ -308,12 +297,14 @@ class KernelPrecomputedIntegers public: using ParallelExtentMap = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; - KernelPrecomputedIntegers(Fusion* fusion, GpuLower& lower); + KernelPrecomputedIntegers(kir::Kernel* kernel); //! Bind concrete values from fusion runtime inputs - void bindKernelInputs(const at::ArrayRef& aten_inputs); + void bindKernelInputs( + kir::Kernel* kernel, + const at::ArrayRef& aten_inputs); //! Bind concrete values from launch constraints void bindParallelExtents( @@ -326,7 +317,7 @@ class KernelPrecomputedIntegers void bindConcreteParallelTypeValue(ParallelType pt, int64_t value); private: - void bindTensorMetaData(kir::TensorView* tv, const at::Tensor& at_tensor); + void bindTensorMetaData(TensorView* tv, const at::Tensor& at_tensor); //! Iterate through all the named scalars corresponding //! to thread sizes and pre-group them by their parallel @@ -334,8 +325,6 @@ class KernelPrecomputedIntegers void initializeNamedScalars(); private: - GpuLower* lower_ = nullptr; - //! Contains all the named scalars correspond //! to thread size of each parallel type. std::unordered_map>, TypeHash> diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 647cf4ec0e2..5e6f2d9375e 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -1,3 +1,4 @@ + #include #include @@ -8,21 +9,11 @@ #include #include #include -#include #include #include #include #include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - #include #include #include @@ -108,8 +99,6 @@ void FusionExecutor::debugCompileFusionFromStr( const std::string& name, int id, CompileOptions options) { - fusion_ = *fusion; - FusionGuard fg(&fusion_); options_ = options; if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) { @@ -126,11 +115,12 @@ void FusionExecutor::debugCompileFusionFromStr( << std::endl; } - setUsedTVs(); + lowered_ = std::make_unique(fusion); + const auto kernel = lowered_->kernel(); + fusion_ = lowered_->kernel(); fusion_id_ = id; - lowered_ = GpuLower(&fusion_); - const auto kernel = lowered_.kernel(); + setUsedTVs(); if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) { kernel->print(); @@ -155,9 +145,9 @@ void FusionExecutor::debugCompileFusionFromStr( void FusionExecutor::compileFusion( Fusion* fusion, - CompileOptions options, const at::ArrayRef& inputs, - const LaunchParams& launch_constraints) { + const LaunchParams& launch_constraints, + CompileOptions options) { FUSER_PERF_SCOPE("compileFusion"); TORCH_INTERNAL_ASSERT( @@ -175,9 +165,6 @@ void FusionExecutor::compileFusion( fusion->printMath(); } - // Clone the fusion so we can store it - fusion_ = *fusion; - FusionGuard fg(&fusion_); options_ = options; c10::DeviceGuard dg(options_.device); @@ -187,11 +174,12 @@ void FusionExecutor::compileFusion( max_device_smem = properties->sharedMemPerBlock; warp_size_ = properties->warpSize; - setUsedTVs(); + lowered_ = std::make_unique(fusion); + const auto kernel = lowered_->kernel(); + fusion_ = lowered_->kernel()->as(); fusion_id_ = ++fusion_id_counter_; - lowered_ = GpuLower(&fusion_); - const auto kernel = lowered_.kernel(); + setUsedTVs(); if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) { kernel->print(); @@ -216,7 +204,7 @@ void FusionExecutor::compileFusion( std::stringstream ss; ss << "Allocations must be based on constant integers for local memory. However, found: "; for (auto alloc : kernel_summary.dynamic_lmem_allocations) { - ss << toString(alloc->buffer(), false) << ", "; + ss << alloc->buffer()->toString() << ", "; } ss << " have dynamic allocations but are placed in local memory."; TORCH_INTERNAL_ASSERT(false, ss.str()); @@ -233,6 +221,8 @@ void FusionExecutor::compileFusion( block_size > 0, "launch param inferred block size < 0"); } + block_size_high_water_mark = + block_size.has_value() ? block_size.value() : block_size_high_water_mark; compiled_kernel_ = executor_utils::nvrtcCompile( structured_code, (kernelNamespace() + "::" + kernelName()).c_str(), @@ -245,8 +235,8 @@ void FusionExecutor::compileFusion( namespace { at::Tensor inferAndAlloc( - const kir::TensorView* tv, - const std::vector& sizes, + const TensorView* tv, + const std::vector& sizes, kir::ExpressionEvaluator& expr_eval, const CompileOptions& options, bool zero_init = false) { @@ -260,9 +250,11 @@ at::Tensor inferAndAlloc( TORCH_INTERNAL_ASSERT( inferred_val.has_value(), "Could not launch kernel as program could not infer ", - kir::toString(size), - " for the buffer ", - kir::toString(tv)); + size->toString(), + "(", + size->name(), + ") for the buffer ", + tv->toString()); inferred_sizes.push_back(inferred_val.value()); } @@ -283,19 +275,20 @@ at::Tensor inferAndAlloc( } at::Tensor inferAndAllocOutput( - const kir::TensorView* tv, + const TensorView* tv, kir::ExpressionEvaluator& expr_eval, const CompileOptions& options, bool zero_init = false) { const auto domain = tv->domain(); - const auto maybe_rfactor_domain = - domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain(); + const auto maybe_rfactor_domain = domain->hasRFactor() + ? domain->getRFactorDomain() + : domain->getRootDomain(); - std::vector sizes; + std::vector sizes; for (const auto id : maybe_rfactor_domain) { if (id->isReduction() || id->isStride() || - id->iterType() == IterType::BroadcastWithoutStride) { + id->getIterType() == IterType::BroadcastWithoutStride) { continue; } sizes.push_back(id->extent()); @@ -348,8 +341,7 @@ LaunchParams FusionExecutor::computeLaunchParams( auto data_cache = compileTimeDataCache(); - auto& lower = lowered_; - + auto lower = lowered_.get(); auto& used_tvs = getUsedTVs(); auto parallel_binding_ids_entry = executor_utils::caching::ExecutorCompileTimeEntry< @@ -364,9 +356,8 @@ LaunchParams FusionExecutor::computeLaunchParams( auto parallel_iter_extent_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::ParallelIterExtentMap>( - data_cache, [¶llel_binding_ids, &lower]() { - return executor_utils::getParallelIterExtents( - lower, parallel_binding_ids); + data_cache, [¶llel_binding_ids]() { + return executor_utils::getParallelIterExtents(parallel_binding_ids); }); auto& parallel_iter_extents = parallel_iter_extent_entry.get(); @@ -385,7 +376,7 @@ LaunchParams FusionExecutor::computeLaunchParams( executor_utils::caching::WarpPaddedParallelExtents>( data_cache, [¶llel_binding_ids, &lower]() { return executor_utils::getWarpPaddedExtentsInfo( - lower, parallel_binding_ids); + lower->kernel(), parallel_binding_ids); }); auto& warp_padded_extent_set = warp_padded_parallel_entry.get().warp_padded_extent_set; @@ -446,7 +437,9 @@ LaunchParams FusionExecutor::computeLaunchParams( auto val = expr_eval.evaluate(extent); TORCH_INTERNAL_ASSERT( val.has_value(), - "Tried to evaluate the extent of ", + "Tried to evaluate the extent, ", + extent->toInlineString(), + " for the ptype: ", p_type, " to set launch bounds but could not."); @@ -481,14 +474,15 @@ LaunchParams FusionExecutor::computeLaunchParams( expr_eval.precomputedIntegers()->evaluate(); } - const auto kernel = lowered_.kernel(); + const auto kernel = lowered_->kernel(); const auto& kernel_summary = kernel->summary(); // Calculate Dynamic Shared Memory Size // Add workspace for reduction and broadcast uint64_t reduction_broadcast_workspace = 0; const bool has_workspace = kernel_summary.has_block_reductions || - kernel_summary.has_grid_reductions || kernel_summary.has_block_broadcasts; + kernel_summary.has_grid_reductions || + kernel_summary.has_block_broadcasts || kernel_summary.has_grid_broadcasts; if (has_workspace && kernel_summary.largest_smem_data_type != DataType::Null) { // Not using nThreads here since it does not handle uninitialized value @@ -533,14 +527,14 @@ FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals( kir::ExpressionEvaluator& expr_eval) { FUSER_PERF_SCOPE("FusionExecutor::AllocGlobalVals"); GlobalBuffers global_buffers; - const auto kernel = lowered_.kernel(); - const auto& kernel_summary = lowered_.kernel()->summary(); + const auto kernel = lowered_->kernel(); + const auto& kernel_summary = lowered_->kernel()->summary(); for (auto alloc : kernel_summary.global_allocations) { TORCH_INTERNAL_ASSERT( - alloc->buffer()->isA(), + alloc->buffer()->isA(), "Cannot allocate global buffers that are not tensors."); - auto tv = alloc->buffer()->as(); - if (kernel->isOutput(tv)) { + auto tv = alloc->buffer()->as(); + if (tv->isFusionOutput()) { continue; } if (alloc->zeroInit()) { @@ -561,14 +555,14 @@ std::vector FusionExecutor::allocOutputs( kir::ExpressionEvaluator& expr_eval, const std::unordered_set& alias_indices) { FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs"); - const auto kernel = lowered_.kernel(); + const auto kernel = lowered_->kernel(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector outputs; for (const auto i : c10::irange(kernel->outputs().size())) { TORCH_INTERNAL_ASSERT( - kernel->outputs()[i]->isA(), + kernel->outputs()[i]->isA(), "Cannot allocate outputs that are not tensors."); - auto output = kernel->outputs()[i]->as(); + auto output = kernel->outputs()[i]->as(); if (alias_indices.count(i) == 0) { outputs.push_back( inferAndAllocOutput(output, expr_eval, options_, false)); @@ -581,7 +575,7 @@ std::vector FusionExecutor::allocOutputs( } void FusionExecutor::setUsedTVs() { - auto used_vals = fusion_.usedMathVals(); + auto used_vals = fusion_->usedMathVals(); auto used_tvs = ir_utils::filterByType(used_vals); used_tvs_.clear(); @@ -595,7 +589,7 @@ std::vector FusionExecutor::runFusion( const LaunchParams& launch_constraints, const c10::optional& opt_code) { FUSER_PERF_SCOPE("FusionExecutor::RunFusion"); - + TORCH_INTERNAL_ASSERT(compiled()); TORCH_INTERNAL_ASSERT( fusion_id_ > 0, "Cannot run fusion, it was not compiled."); TORCH_INTERNAL_ASSERT( @@ -607,11 +601,10 @@ std::vector FusionExecutor::runFusion( executor_entry = &executor_entry_lookup_[*opt_code]; } - FusionGuard fg(&fusion_); c10::DeviceGuard dg(options_.device); auto stream = at::cuda::getCurrentCUDAStream(); executor_utils::initializeCudaContext(); - + TORCH_INTERNAL_ASSERT(lowered_); LaunchParams launch_params; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector allocated_outputs = outputs; @@ -642,7 +635,7 @@ std::vector FusionExecutor::runFusion( } } else { TORCH_INTERNAL_ASSERT( - outputs.size() == fusion_.outputs().size(), + outputs.size() == fusion_->outputs().size(), __func__, " provided number of outputs does match fusion output"); } @@ -672,20 +665,35 @@ std::vector FusionExecutor::runFusion( // code path to take when either: // 1. no opt_code is provided or // 2. `executor_entry` is not initialized - executor_utils::validateKernelInputs(&fusion_, inputs, options_.device); + executor_utils::validateKernelInputs(fusion_, inputs, options_.device); if (!evaluator_precomputed_integers_) { evaluator_precomputed_integers_ = - std::make_unique(&fusion_, lowered_); + std::make_unique(lowered_->kernel()); } kir::ExpressionEvaluator expr_eval; - evaluator_precomputed_integers_->bindKernelInputs(inputs); + evaluator_precomputed_integers_->bindKernelInputs( + lowered_->kernel(), inputs); expr_eval.precomputedIntegers() = evaluator_precomputed_integers_.get(); launch_params = computeLaunchParams(launch_constraints, expr_eval, warp_size_); + // Recompile the kernel if the number of threads in the block has increased + if (launch_params.nThreads() > block_size_high_water_mark) { + const auto kernel = lowered_->kernel(); + const auto kernel_code = + codegen::generateCudaKernel(kernel, kernelName()); + const auto structured_code = getStructuredCode(kernel_code); + block_size_high_water_mark = launch_params.nThreads(); + compiled_kernel_ = executor_utils::nvrtcCompile( + structured_code, + (kernelNamespace() + "::" + kernelName()).c_str(), + fusion_id_, + block_size_high_water_mark); + } + if (kernel()->summary().has_cooperative_grid_reduction) { #ifndef __HIP_PLATFORM_HCC__ int num_blocks_per_SM = -1; @@ -716,16 +724,18 @@ std::vector FusionExecutor::runFusion( } executor_utils::validateVectorizedTensors( - &fusion_, inputs, outputs, lowered_, compileTimeDataCache(), expr_eval); - - auto& fusion = fusion_; + lowered_.get()->kernel(), + inputs, + outputs, + compileTimeDataCache(), + expr_eval); auto alias_indices_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::InputAliasIndices>( - compileTimeDataCache(), [&fusion]() { + compileTimeDataCache(), [&]() { return std::make_unique>>( - fusion.getInputAliasIndices()); + fusion_->getInputAliasIndices()); }); auto& alias_indices = alias_indices_entry.get(); @@ -736,9 +746,9 @@ std::vector FusionExecutor::runFusion( auto output_alias_indices_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::OutputAliasIndices>( - compileTimeDataCache(), [&fusion]() { + compileTimeDataCache(), [&]() { return std::make_unique>( - fusion.getOutputAliasIndices()); + fusion_->getOutputAliasIndices()); }); auto& output_alias_indices = output_alias_indices_entry.get(); @@ -753,7 +763,7 @@ std::vector FusionExecutor::runFusion( } else { // TODO: Update this as well; executor_utils::validateKernelOutputs( - &fusion_, allocated_outputs, options_.device); + fusion_, allocated_outputs, options_.device); } global_buffers = allocGlobalVals(expr_eval); @@ -802,7 +812,7 @@ std::vector FusionExecutor::runFusion( kernel_arguments.push(inputs); kernel_arguments.push(allocated_outputs); kernel_arguments.push(global_buffers.buffers); - if (lowered_.kernel()->summary().is_stochastic) { + if (lowered_->kernel()->summary().is_stochastic) { kernel_arguments.appendPhiloxRNGSeed(rand_offset); } } diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 523f2aa0e4b..40accbfb520 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -35,9 +35,9 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { void compileFusion( Fusion* fusion, - CompileOptions options = CompileOptions(), const at::ArrayRef& inputs = {}, - const LaunchParams& launch_constraints = LaunchParams()); + const LaunchParams& launch_constraints = LaunchParams(), + CompileOptions options = CompileOptions()); std::vector runFusion( const at::ArrayRef& inputs, @@ -55,7 +55,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { // function to query whether a `FusionExecutor` has a compiled kernel to // execute bool compiled() const { - return fusion_id_ != -1; + return fusion_id_ != -1 && lowered_; }; void evictCache(size_t cache_id) { @@ -85,7 +85,8 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { executor_utils::caching::ExecutorCompileTimeInfoCache; kir::Kernel* kernel() const { - return lowered_.kernel(); + TORCH_INTERNAL_ASSERT(lowered_); + return lowered_->kernel(); } //! Internal knob used for debugging/profiling only @@ -178,8 +179,6 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { } private: - Fusion fusion_; - CompileOptions options_; size_t max_device_smem = std::numeric_limits().max(); int warp_size_ = 0; @@ -192,7 +191,13 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { int fusion_id_ = -1; static int fusion_id_counter_; - GpuLower lowered_; + std::unique_ptr lowered_; + // Copy of lowered_->kernel() + Fusion* fusion_ = nullptr; + + // Track the block size this kernel was compiled with. If the block size + // increases, recompile to adjust maxregister count. + int64_t block_size_high_water_mark = 1; // lookup table to take short cut to retrieve recorded information in order to // launch kernels without re-inference parameters. diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp index 968570c1086..883fae207c5 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp @@ -1,4 +1,3 @@ -#include #include // Extract size and strides @@ -65,7 +64,7 @@ std::unique_ptr getTensorArg(int nDims) { false, "Tried to generate a tensor to run a generated kernel with ", nDims, - " dimensions, however it must be a size 0 to 8 dimensional tensor."); + " dimensions, however only 0 to 8 dimensional tensor are supported."); } return nullptr; } @@ -98,8 +97,6 @@ std::unique_ptr getTensorArg( } } -} // namespace - std::unique_ptr getTensorArg( c10::ScalarType dtype, int nDims, @@ -117,20 +114,73 @@ std::unique_ptr getTensorArg( return nullptr; } +} // namespace + // Push a tensor to the arguments void KernelArgumentHolder::push(const at::Tensor& tensor) { changed_ = true; - int nDims = tensor.ndimension(); - - c10::ScalarType dtype = tensor.scalar_type(); - std::unique_ptr tensor_arg = - getTensorArg(dtype, nDims, index_mode_); - tensor_arg->setPointer(tensor.data_ptr()); - for (const auto i : c10::irange(nDims)) { - tensor_arg->setSize(i, tensor.sizes()[i]); - tensor_arg->setStride(i, tensor.strides()[i]); + if (is_cpu_scalar(tensor)) { + switch (tensor.scalar_type()) { + case c10::ScalarType::Double: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Float: + arguments_.push_back( + std::make_unique>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Half: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::BFloat16: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Bool: + arguments_.push_back( + std::make_unique>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Long: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + case c10::ScalarType::Int: + arguments_.push_back( + std::make_unique< + CpuScalarTensorArg>>( + tensor.data_ptr()[0])); + break; + default: + TORCH_CHECK( + false, + "Dtype: ", + tensor.scalar_type(), + " not currently supported in code generated kernels."); + } + } else { + int nDims = tensor.ndimension(); + + c10::ScalarType dtype = tensor.scalar_type(); + std::unique_ptr tensor_arg = + getTensorArg(dtype, nDims, index_mode_); + tensor_arg->setPointer(tensor.data_ptr()); + for (const auto i : c10::irange(nDims)) { + tensor_arg->setSize(i, tensor.sizes()[i]); + tensor_arg->setStride(i, tensor.strides()[i]); + } + arguments_.push_back(std::move(tensor_arg)); } - arguments_.push_back(std::move(tensor_arg)); } // Push a scalar or integer to the arguments diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h index d306683c43d..d457a69adb2 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h @@ -33,6 +33,7 @@ struct TensorArgCodegen { } }; +// 0-Dim GPU based tensor template struct TensorArgCodegen { T& operator[](nvfuser_index_t ind) { @@ -51,6 +52,17 @@ struct TensorArgCodegen { } }; +// Specialization for 0-dim case that's easy to pass in a CPU based tensor +// without memcpy +template +struct CpuScalarTensorCodegen { + T& operator[](int) { + return data; + }; + + T data; +}; + struct ArgAbstract { virtual ~ArgAbstract() = default; virtual void* arg() = 0; @@ -67,7 +79,7 @@ struct PhiloxCudaStateArg : public ArgAbstract { struct LongArg : public ArgAbstract { int64_t val_; - explicit LongArg(int64_t _val) : val_(_val){}; + explicit LongArg(int64_t _val) : val_(_val) {} // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) void* arg() { return &val_; @@ -76,7 +88,7 @@ struct LongArg : public ArgAbstract { struct DoubleArg : public ArgAbstract { double val_; - explicit DoubleArg(double _val) : val_(_val){}; + explicit DoubleArg(double _val) : val_(_val) {} // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) void* arg() { return &val_; @@ -85,7 +97,7 @@ struct DoubleArg : public ArgAbstract { struct BoolArg : public ArgAbstract { bool val_; - explicit BoolArg(bool _val) : val_(_val){}; + explicit BoolArg(bool _val) : val_(_val) {} // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) void* arg() { return &val_; @@ -119,9 +131,20 @@ struct TensorArg : public TensorArgAbstract { } }; -std::unique_ptr getTensorArg( - c10::ScalarType dtype, - int nDims); +template +struct CpuScalarTensorArg : public ArgAbstract { + CPU_TENSOR_TYPE instance_; + + CpuScalarTensorArg() = delete; + + explicit CpuScalarTensorArg(decltype(CPU_TENSOR_TYPE::data) _data) { + instance_.data = _data; + } + + void* arg() override { + return &instance_; + } +}; class KernelArgumentHolder { public: diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 13cdc29099e..5323036e5df 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -110,6 +109,16 @@ bool validateKernelArgTensor( return false; } + if (is_cpu_scalar(arg) && !param->as()->isCpuScalar()) { + msg << "Argument is CPU Scalar Tensor, but parameter is not.\n"; + return false; + } + + if (!is_cpu_scalar(arg) && !arg.is_cuda()) { + msg << "Argumnet is a CPU tensor which is not supported in fusions.\n"; + return false; + } + // Check the rank of the tensors. size_t arg_dim = arg.dim(); // Note: This requires current Fusion to be active. @@ -126,7 +135,7 @@ bool validateKernelArgTensor( return false; } - if (arg.device() != device) { + if (!is_cpu_scalar(arg) && arg.device() != device) { msg << "Argument is on device that is not compiled for." << "\n"; return false; @@ -339,6 +348,8 @@ void validateKernelOutputs( !mismatch, "Found one or more invalid arguments: ", msg.str()); } +namespace { + bool canVectorize(const IValue& aten_val, int word_size) { if (!aten_val.isTensor()) { return false; @@ -371,16 +382,18 @@ bool canVectorize(const IValue& aten_val, int word_size) { return true; } +// Returns true if a TV can be used with ParallelType::Vectorize. When +// input or output tensors are involved, the other version of +// canVectorize is used. bool canVectorize( - TensorView* fusion_tv, + TensorView* tv, int word_size, - GpuLower& lower, kir::ExpressionEvaluator& expr_eval) { IterDomain* last_root_dim = nullptr; - // TODO: Should this be rfactor instead of root?? - for (size_t i = fusion_tv->getRootDomain().size(); i > 0; i--) { - auto r_id = fusion_tv->getRootDomain()[i - 1]; - if (r_id->isReduction() || r_id->isBroadcast()) { + for (size_t i = tv->getRootDomain().size(); i > 0; i--) { + auto r_id = tv->getRootDomain()[i - 1]; + if (r_id->isReduction() || r_id->isTrivialReduction() || + r_id->isBroadcast()) { continue; } last_root_dim = r_id; @@ -391,8 +404,7 @@ bool canVectorize( return false; } - auto last_dim_size = - expr_eval.evaluate(lower.lowerValue(last_root_dim->extent())); + auto last_dim_size = expr_eval.evaluate(last_root_dim->extent()); if (!last_dim_size.has_value()) { return false; @@ -405,8 +417,6 @@ bool canVectorize( return true; } -namespace { - // Check if there's any split that is non-divisible and vectorized. If // found, Vectorize is illegal. void validateVectorizedSplits( @@ -418,12 +428,12 @@ void validateVectorizedSplits( TORCH_INTERNAL_ASSERT( input_extent.has_value(), "Could not check if a split with vectorization is divisible because the extent, ", - kir::toString(extent_factor.first), + extent_factor.first->toString(), ", is not possible to evaluate."); TORCH_INTERNAL_ASSERT( input_extent.has_value(), "Could not check if a split with vectorization is divisible because the split factor, ", - kir::toString(extent_factor.second), + extent_factor.second->toString(), ", is not possible to evaluate."); TORCH_INTERNAL_ASSERT( input_extent.value() % split_factor.value() == 0, @@ -435,16 +445,144 @@ void validateVectorizedSplits( } } +//! Returns the position information of vectorized input/output tensors +//! in the given fusion. +std::unique_ptr getVectorizedTensorValidationInfo( + Fusion* fusion) { + auto vectorized_tensor_info_ptr = + std::make_unique(); + auto& tv_to_vector_word_size = + vectorized_tensor_info_ptr->tv_to_vector_word_size; + auto& global_inp_misaligned_tv = + vectorized_tensor_info_ptr->global_inp_misaligned_tv; + auto& global_out_misaligned_tv = + vectorized_tensor_info_ptr->global_out_misaligned_tv; + + kir::ExpressionEvaluator expr_eval; + + // Find all vectorized tensors and their word size + for (auto expr : fusion->exprs()) { + if (!expr->isA() || + expr->as()->getUnaryOpType() != UnaryOpType::Set) { + continue; + } + auto uop = expr->as(); + if (!uop->out()->isA() || !uop->in()->isA()) { + continue; + } + auto out_tv = uop->out()->as(); + auto in_tv = uop->in()->as(); + IterDomain* vector_dim = nullptr; + for (auto id : out_tv->domain()->domain()) { + if (id->getParallelType() == ParallelType::Vectorize || + id->getParallelType() == ParallelType::MisalignedVectorize) { + TORCH_INTERNAL_ASSERT( + vector_dim == nullptr, + "Found multiple vectorized dimensions on tensor ", + out_tv); + vector_dim = id; + } + } + if (vector_dim == nullptr) { + continue; + } + auto vector_word_size = expr_eval.evaluate(vector_dim->extent()); + TORCH_INTERNAL_ASSERT( + vector_word_size.has_value(), + "Non constant vector dimension found in ", + out_tv); + + // The expression here must be a UnaryOp::Set, so checking either of the + // input or output tensor should be sufficient. When the output is a + // fusion output, check the tensor as its size information is available + // without using the expression evaluator. + auto tv_to_verify = out_tv->isFusionOutput() ? out_tv : in_tv; + tv_to_vector_word_size[tv_to_verify] = vector_word_size.value(); + + if (vector_dim->getParallelType() == ParallelType::MisalignedVectorize) { + TORCH_INTERNAL_ASSERT( + in_tv->isFusionInput() || out_tv->isFusionOutput(), + "MisalignedVectorize is assumed to be used with either input or output tensor"); + if (out_tv->getMemoryType() == MemoryType::Global && + in_tv->getMemoryType() == MemoryType::Local) { + global_out_misaligned_tv.insert(out_tv); + } else if ( + in_tv->getMemoryType() == MemoryType::Global && + out_tv->getMemoryType() == MemoryType::Local) { + global_inp_misaligned_tv.insert(in_tv); + } else { + TORCH_INTERNAL_ASSERT( + false, + "Unsupported memory configuration for misaligned vectorization."); + } + } + } + + // Check striding information on input and outputs as well as size information + // of all + auto& inp_misaligned_tensors_pos = + vectorized_tensor_info_ptr->inp_misaligned_tensors_pos; + auto& out_misaligned_tensors_pos = + vectorized_tensor_info_ptr->out_misaligned_tensors_pos; + auto& inp_pos_to_word_size_map_to_verify = + vectorized_tensor_info_ptr->inp_pos_to_word_size_map_to_verify; + auto& out_pos_to_word_size_map_to_verify = + vectorized_tensor_info_ptr->out_pos_to_word_size_map_to_verify; + auto& intermediate_tv_to_word_size_map_to_verify = + vectorized_tensor_info_ptr->intermediate_tv_to_word_size_map_to_verify; + + for (auto entry : tv_to_vector_word_size) { + auto tv = entry.first; + auto word_size = entry.second; + if (tv->isFusionInput()) { + auto inp_it = + std::find(fusion->inputs().begin(), fusion->inputs().end(), tv); + TORCH_INTERNAL_ASSERT( + inp_it != fusion->inputs().end(), + "Could not find ", + tv, + " in fusion inputs."); + auto inp_pos = std::distance(fusion->inputs().begin(), inp_it); + + if (global_inp_misaligned_tv.find(tv) != global_inp_misaligned_tv.end()) { + inp_misaligned_tensors_pos.emplace_back(inp_pos); + } else { + // Shouldn't visit same pos twice here, assert ? + inp_pos_to_word_size_map_to_verify[inp_pos] = word_size; + } + } else if (tv->isFusionOutput()) { + auto out_it = + std::find(fusion->outputs().begin(), fusion->outputs().end(), tv); + TORCH_INTERNAL_ASSERT( + out_it != fusion->outputs().end(), + "Could not find ", + tv, + " in provided fusion outputs."); + auto out_pos = std::distance(fusion->outputs().begin(), out_it); + + if (global_out_misaligned_tv.find(tv) != global_out_misaligned_tv.end()) { + out_misaligned_tensors_pos.emplace_back(out_pos); + } else { + out_pos_to_word_size_map_to_verify[out_pos] = word_size; + } + } else { + // Intermediate tensors. Note that this must be Vectorize as + // MisalignedVectorize is only supported for inputs and outputs. + intermediate_tv_to_word_size_map_to_verify[tv] = word_size; + } + } + + return vectorized_tensor_info_ptr; +} } // namespace // Misaligned vectorization check. Currently misaligned vectorization is limited // to global-register and register-global load/store patterns. However, this // could be improved to include shared memory. void validateVectorizedTensors( - Fusion* fusion, + kir::Kernel* kernel, const at::ArrayRef& inputs, const std::vector& outputs, - GpuLower& lower, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval) { FUSER_PERF_SCOPE("FusionExecutor::validateVectorizedTensors"); @@ -452,9 +590,8 @@ void validateVectorizedTensors( auto tensor_vectorization_validation_entry = executor_utils::caching::ExecutorCompileTimeEntry< executor_utils::caching::VectorizedTensorValidation>( - data_cache, [fusion, &lower]() { - return executor_utils::getVectorizedTensorValidationInfo( - fusion, lower); + data_cache, [kernel]() { + return executor_utils::getVectorizedTensorValidationInfo(kernel); }); // Validate all the canVectorizes: @@ -463,7 +600,7 @@ void validateVectorizedTensors( TORCH_INTERNAL_ASSERT( canVectorize(inputs[it.first], it.second), "Error vectorizing, ", - fusion->inputs()[it.first], + kernel->inputs()[it.first], " as input provided does not allowed vectorization by word size, ", it.second); } @@ -474,12 +611,24 @@ void validateVectorizedTensors( TORCH_INTERNAL_ASSERT( canVectorize(outputs[it.first], it.second), "Error vectorizing, ", - fusion->outputs()[it.first], + kernel->outputs()[it.first], " as output provided does not allowed vectorization by word size, ", it.second); } } + for (auto it : tensor_vectorization_validation_entry.get() + .intermediate_tv_to_word_size_map_to_verify) { + auto tv = it.first; + auto vec_width = it.second; + TORCH_INTERNAL_ASSERT( + canVectorize(tv, vec_width, expr_eval), + "Error vectorizing, ", + tv->toString(), + " as the extent of the vectorized axis does not allowed vectorization by word size, ", + vec_width); + } + std::vector inp_misaligned_tensors; std::vector out_misaligned_tensors; @@ -511,7 +660,7 @@ void validateVectorizedTensors( out_misaligned_tensors), "All global tensors must have the same stride for misaligned vectorization."); - validateVectorizedSplits(lower.kernel(), expr_eval); + validateVectorizedSplits(kernel, expr_eval); } kir::ExpressionEvaluator bindKernelInputs( @@ -530,7 +679,7 @@ kir::ExpressionEvaluator bindKernelInputs( for (const auto i : c10::irange(inputs.size())) { const auto input = inputs[i]; - if (auto tensor_input = dynamic_cast(input)) { + if (auto tensor_input = dynamic_cast(input)) { TORCH_INTERNAL_ASSERT( aten_inputs[i].isTensor(), "Something went wrong configuring launch. Inputs no longer match at index:", @@ -538,7 +687,7 @@ kir::ExpressionEvaluator bindKernelInputs( const auto aten_tensor = aten_inputs[i].toTensor(); const auto root_domain = - kir::TensorDomain::noReductions(tensor_input->domain()->rootDomain()); + TensorDomain::noReductions(tensor_input->domain()->getRootDomain()); TORCH_INTERNAL_ASSERT( aten_tensor.ndimension() == static_cast(root_domain.size()), "Something went wrong configuring launch. Inputs no longer match."); @@ -553,7 +702,7 @@ kir::ExpressionEvaluator bindKernelInputs( TORCH_CHECK( *prev_value == value, "Attempting to bind ", - kir::toString(extent), + extent->toString(), " to ", value, "but it's already set to ", @@ -561,7 +710,7 @@ kir::ExpressionEvaluator bindKernelInputs( should_bind = false; } } - if (should_bind && !extent->isConst()) { + if (should_bind && !extent->isConstScalar()) { expr_eval.bind(extent, value); } } @@ -697,24 +846,19 @@ NvrtcFunction nvrtcCompile( "--std=c++14", compute.c_str(), "-default-device"}; #endif - const char* disable_fastmath = getenv("PYTORCH_NVFUSER_DISABLE_FASTMATH"); - if (!disable_fastmath || (atoi(disable_fastmath) == 0)) { - args.push_back("--use_fast_math"); - } else { - TORCH_WARN_ONCE( - "fast math disabled in nvfuser, try set `PYTORCH_NVFUSER_DISABLE_FASTMATH=0`"); - } - const char* disable_fma = getenv("PYTORCH_NVFUSER_DISABLE_FMA"); - // int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0; - if (disable_fma && atoi(disable_fma)) { #ifdef __HIP_PLATFORM_HCC__ + if (disable_fma && atoi(disable_fma)) { TORCH_WARN_ONCE( "PYTORCH_CUDA_FUSER_DISABLE_FMA is not supported on ROCm, ignoring"); + } #else + if (disable_fma && atoi(disable_fma)) { args.push_back("--fmad=false"); -#endif + } else { + args.push_back("--fmad=true"); } +#endif #ifndef NDEBUG // Add line info to generated kernels @@ -1037,7 +1181,7 @@ template class ExecutorCompileTimeEntry; } // namespace caching std::vector getParallelBindingsIterDomains( - GpuLower& lower, + GpuLower* lower, const std::vector& used_tvs) { std::vector parallel_ids; for (auto tv : used_tvs) { @@ -1047,7 +1191,7 @@ std::vector getParallelBindingsIterDomains( // Want to keep the broadcast dimensions if they are not resolved // TODO: piping down the parallel dimension map here would // be helpful - auto& parallel_map = lower.caParallelMap(); + auto& parallel_map = lower->caParallelMap(); if (parallel_map.getConcreteMappedID(id) == id) { parallel_ids.push_back(id); } @@ -1062,39 +1206,41 @@ std::vector getParallelBindingsIterDomains( return parallel_ids; } +namespace { + void insertParallelExtent( - GpuLower& lower, IterDomain* binding_id, const std::unique_ptr& parallel_iter_extents_ptr) { - auto kir_extent = lower.lowerValue(binding_id->extent()); + auto extent = binding_id->extent(); const auto it = parallel_iter_extents_ptr->find(binding_id->getParallelType()); if (it != parallel_iter_extents_ptr->end()) { - it->second.push_back(kir_extent); + it->second.push_back(extent); } else { parallel_iter_extents_ptr->operator[](binding_id->getParallelType()) = { - kir_extent}; + extent}; } } +} // namespace + std::unique_ptr getParallelIterExtents( - GpuLower& lower, std::vector& parallel_binding_ids) { auto parallel_iter_extents_ptr = std::make_unique(); for (auto id : parallel_binding_ids) { - insertParallelExtent(lower, id, parallel_iter_extents_ptr); + insertParallelExtent(id, parallel_iter_extents_ptr); } return parallel_iter_extents_ptr; } std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower& lower, + GpuLower* lower, std::vector& parallel_binding_ids) { auto parallel_iter_extents_ptr = std::make_unique(); - auto& parallel_map = lower.caParallelMap(); + auto& parallel_map = lower->caParallelMap(); std::vector mapped; - bool is_tidx_warp_padded = lower.getWarpPaddedParallelInfo().is_tidx_padded; + bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded; for (auto id : parallel_binding_ids) { if (std::any_of( @@ -1109,7 +1255,7 @@ std::unique_ptr getSimplifiedParallelIterExtents( } insertParallelExtent( - lower, parallel_map.getConcreteMappedID(id), parallel_iter_extents_ptr); + parallel_map.getConcreteMappedID(id), parallel_iter_extents_ptr); mapped.push_back(id); } @@ -1117,7 +1263,7 @@ std::unique_ptr getSimplifiedParallelIterExtents( } std::unique_ptr getWarpPaddedExtentsInfo( - GpuLower& lower, + kir::Kernel* kernel, std::vector& parallel_binding_ids) { auto warp_padded_extent_info_ptr = std::make_unique(); @@ -1125,7 +1271,6 @@ std::unique_ptr getWarpPaddedExtentsInfo( warp_padded_extent_info_ptr->warp_padded_extent_set; auto& warp_padded_constant = warp_padded_extent_info_ptr->warp_padded_constant; - auto kernel = lower.kernel(); bool has_warp_reduction = kernel->getWarpPaddedParallelInfo().has_warp_reduction; @@ -1135,11 +1280,11 @@ std::unique_ptr getWarpPaddedExtentsInfo( if (has_warp_reduction) { if (id->hasPaddingToMultipleOfWarp() || kernel->isParallelTypePadded(id->getParallelType())) { - auto kir_extent = lower.lowerValue(id->extent()); - warp_padded_extent_set.insert(kir_extent); + auto extent = id->extent(); + warp_padded_extent_set.insert(extent); auto padded_value = id->getMaybeSizeAfterPadding(); if (padded_value.has_value()) { - warp_padded_constant[kir_extent] = padded_value.value(); + warp_padded_constant[extent] = padded_value.value(); } } } @@ -1147,122 +1292,6 @@ std::unique_ptr getWarpPaddedExtentsInfo( return warp_padded_extent_info_ptr; } -std::unique_ptr getVectorizedTensorValidationInfo( - Fusion* fusion, - GpuLower& lower) { - auto vectorized_tensor_info_ptr = - std::make_unique(); - auto& tv_to_vector_word_size = - vectorized_tensor_info_ptr->tv_to_vector_word_size; - auto& global_inp_misaligned_tv = - vectorized_tensor_info_ptr->global_inp_misaligned_tv; - auto& global_out_misaligned_tv = - vectorized_tensor_info_ptr->global_out_misaligned_tv; - - kir::ExpressionEvaluator expr_eval; - - // Find all vectorized tensors and their word size - for (auto expr : fusion->exprs()) { - if (!expr->isA() || - expr->as()->getUnaryOpType() != UnaryOpType::Set) { - continue; - } - auto uop = expr->as(); - if (!uop->out()->isA() || !uop->in()->isA()) { - continue; - } - auto out_tv = uop->out()->as(); - auto in_tv = uop->in()->as(); - IterDomain* vector_dim = nullptr; - for (auto id : out_tv->domain()->domain()) { - if (id->getParallelType() == ParallelType::Vectorize || - id->getParallelType() == ParallelType::MisalignedVectorize) { - TORCH_INTERNAL_ASSERT( - vector_dim == nullptr, - "Found multiple vectorized dimensions on tensor ", - out_tv); - vector_dim = id; - } - } - if (vector_dim == nullptr) { - continue; - } - auto vector_word_size = - expr_eval.evaluate(lower.lowerValue(vector_dim->extent())); - TORCH_INTERNAL_ASSERT( - vector_word_size.has_value(), - "Non constant vector dimension found in ", - out_tv); - tv_to_vector_word_size[out_tv] = vector_word_size.value(); - tv_to_vector_word_size[in_tv] = vector_word_size.value(); - - if (vector_dim->getParallelType() == ParallelType::MisalignedVectorize) { - if (out_tv->getMemoryType() == MemoryType::Global && - in_tv->getMemoryType() == MemoryType::Local) { - global_out_misaligned_tv.insert(out_tv); - } else if ( - in_tv->getMemoryType() == MemoryType::Global && - out_tv->getMemoryType() == MemoryType::Local) { - global_inp_misaligned_tv.insert(in_tv); - } else { - TORCH_INTERNAL_ASSERT( - false, - "Unsupported memory configuration for misaligned vectorization."); - } - } - } - - // Check striding information on input and outputs as well as size information - // of all - auto& inp_misaligned_tensors_pos = - vectorized_tensor_info_ptr->inp_misaligned_tensors_pos; - auto& out_misaligned_tensors_pos = - vectorized_tensor_info_ptr->out_misaligned_tensors_pos; - auto& inp_pos_to_word_size_map_to_verify = - vectorized_tensor_info_ptr->inp_pos_to_word_size_map_to_verify; - auto& out_pos_to_word_size_map_to_verify = - vectorized_tensor_info_ptr->out_pos_to_word_size_map_to_verify; - - for (auto entry : tv_to_vector_word_size) { - auto tv = entry.first; - auto word_size = entry.second; - if (tv->isFusionInput()) { - auto inp_it = - std::find(fusion->inputs().begin(), fusion->inputs().end(), tv); - TORCH_INTERNAL_ASSERT( - inp_it != fusion->inputs().end(), - "Could not find ", - tv, - " in fusion inputs."); - auto inp_pos = std::distance(fusion->inputs().begin(), inp_it); - - if (global_inp_misaligned_tv.find(tv) != global_inp_misaligned_tv.end()) { - inp_misaligned_tensors_pos.emplace_back(inp_pos); - } else { - // Shouldn't visit same pos twice here, assert ? - inp_pos_to_word_size_map_to_verify[inp_pos] = word_size; - } - } else if (tv->isFusionOutput()) { - auto out_it = - std::find(fusion->outputs().begin(), fusion->outputs().end(), tv); - TORCH_INTERNAL_ASSERT( - out_it != fusion->outputs().end(), - "Could not find ", - tv, - " in provided fusion outputs."); - auto out_pos = std::distance(fusion->outputs().begin(), out_it); - - if (global_out_misaligned_tv.find(tv) != global_out_misaligned_tv.end()) { - out_misaligned_tensors_pos.emplace_back(out_pos); - } else { - out_pos_to_word_size_map_to_verify[out_pos] = word_size; - } - } - } - - return vectorized_tensor_info_ptr; -} - } // namespace executor_utils } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.h b/torch/csrc/jit/codegen/cuda/executor_utils.h index d851be48991..93deec6343f 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.h +++ b/torch/csrc/jit/codegen/cuda/executor_utils.h @@ -28,28 +28,16 @@ namespace executor_utils { // Include all the functions we might need in generated code std::string kernelPreamble(); -// TODO(kir): rewrite in terms of Kernel inputs void validateKernelInputs( Fusion* fusion, const at::ArrayRef& inputs, const c10::Device& device); -// TODO(kir): rewrite in terms of Kernel outputs void validateKernelOutputs( Fusion* fusion, const std::vector& outputs, const c10::Device& device); -// Returns if vectorizing the aten value by word size is possible -bool canVectorize(const IValue& aten_val, int word_size); - -// Returns if vectorizing the aten value by word size is possible -bool canVectorize( - TensorView* fusion_tv, - int word_size, - GpuLower& lower, - kir::ExpressionEvaluator& expr_eval); - //! Bind kernel input values to runtime values kir::ExpressionEvaluator bindKernelInputs( const at::ArrayRef& aten_inputs, @@ -112,7 +100,7 @@ class ParallelBindingIterDomains { class ParallelIterExtentMap { public: using DataType = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; static const CompileTimeEntryType EntryType = CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP; }; @@ -133,7 +121,7 @@ class ParallelIterExtentMap { class SimplifiedParallelIterExtentMap { public: using DataType = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; static const CompileTimeEntryType EntryType = CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP; }; @@ -141,8 +129,8 @@ class SimplifiedParallelIterExtentMap { //! WarpPaddedExtentsInfo: //! Auxiliary data type for entry class WarpPaddedParallelExtents struct WarpPaddedExtentsInfo { - std::unordered_set warp_padded_extent_set; - std::unordered_map warp_padded_constant; + std::unordered_set warp_padded_extent_set; + std::unordered_map warp_padded_constant; }; //! Compile-time info to be cached in each FusionExecutor: @@ -166,6 +154,8 @@ struct VectorizedTensorInfo { std::vector out_misaligned_tensors_pos; std::unordered_map inp_pos_to_word_size_map_to_verify; std::unordered_map out_pos_to_word_size_map_to_verify; + std::unordered_map + intermediate_tv_to_word_size_map_to_verify; }; //! Compile-time info to be cached in each FusionExecutor: @@ -284,42 +274,33 @@ class ExecutorCompileTimeEntry { //! Returns the vector of tensorviews that will be used to bind parallel //! dimensions. std::vector getParallelBindingsIterDomains( - GpuLower& lower, + GpuLower* lower, const std::vector& used_tvs); using ParallelExtentMap = - std::unordered_map, TypeHash>; + std::unordered_map, TypeHash>; //! Returns the extents of all parallel binding iterdomains corresponding //! to each parallel type. std::unique_ptr getParallelIterExtents( - GpuLower& lower, std::vector& parallel_binding_ids); //! Returns the simplified set of extents necessary for launch parameter //! binding. std::unique_ptr getSimplifiedParallelIterExtents( - GpuLower& lower, + GpuLower* lower, std::vector& parallel_binding_ids); //! Returns the symbolic or constant extetns of warp padded parallel //! iterdomains in the given vector. std::unique_ptr getWarpPaddedExtentsInfo( - GpuLower& lower, + kir::Kernel* lower, std::vector& parallel_binding_ids); -//! Returns the position information of vectorized input/output tensors -//! in the given fusion. -std::unique_ptr getVectorizedTensorValidationInfo( - Fusion* fusion, - GpuLower& lower); - -// TODO(kir): rewrite in terms of Kernel tensors void validateVectorizedTensors( - Fusion* fusion, + kir::Kernel* kernel, const at::ArrayRef& inputs, const std::vector& outputs, - GpuLower& lower, caching::ExecutorCompileTimeInfoCache* data_cache, kir::ExpressionEvaluator& expr_eval); diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index ced4b59a783..5630743b6f6 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index d9d71e53c41..be686c0d943 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -8,10 +8,9 @@ #include #include #include +#include #include -#include - namespace torch { namespace jit { namespace fuser { @@ -37,13 +36,7 @@ void swap(Fusion& a, Fusion& b) noexcept { using std::swap; - // Swap the content - swap(a.val_set_, b.val_set_); - swap(a.expr_set_, b.expr_set_); - swap(a.val_deque_, b.val_deque_); - - swap(a.val_type_name_map_, b.val_type_name_map_); - swap(a.expr_name_counter_, b.expr_name_counter_); + swap(static_cast(a), static_cast(b)); swap(a.inputs_, b.inputs_); swap(a.outputs_, b.outputs_); @@ -51,27 +44,6 @@ void swap(Fusion& a, Fusion& b) noexcept { swap(a.io_alias_, b.io_alias_); swap(a.permuted_input_map_, b.permuted_input_map_); swap(a.permuted_output_map_, b.permuted_output_map_); - - // Fixup the Statement::fusion_ links for a - for (auto val : a.val_set_) { - val->fusion_ = &a; - } - for (auto expr : a.expr_set_) { - expr->fusion_ = &a; - } - - // Fixup the Statement::fusion_ links for b - for (auto val : b.val_set_) { - val->fusion_ = &b; - } - for (auto expr : b.expr_set_) { - expr->fusion_ = &b; - } -} - -Fusion::Fusion(const Fusion& other) { - FUSER_PERF_SCOPE("Fusion copy"); - Fusion::copy(&other, this); } std::unique_ptr Fusion::segment( @@ -82,28 +54,13 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - IrCloner ir_cloner(to); + auto ir_cloner = IrContainer::copy(from, to); - for (auto val : from->val_set_) { - to->val_set_.insert(ir_cloner.clone(val)); - } - - for (auto expr : from->expr_set_) { - to->expr_set_.insert(ir_cloner.clone(expr)); - } - - for (auto val : from->val_deque_) { - to->val_deque_.push_back(ir_cloner.clone(val)); - } - - for (auto val : from->val_set_) { + for (auto val : from->vals_) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); } - to->val_type_name_map_ = from->val_type_name_map_; - to->expr_name_counter_ = from->expr_name_counter_; - to->inputs_ = ir_cloner.clone(from->inputs_); to->outputs_ = ir_cloner.clone(from->outputs_); @@ -117,9 +74,22 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->permuted_input_map_ = from->permuted_input_map_; to->permuted_output_map_ = from->permuted_output_map_; + to->all_tv_uses_valid_ = from->all_tv_uses_valid_; + // This should never be true on copy, but copying for completeness. + to->is_during_update_uses_ = from->is_during_update_uses_; + return ir_cloner; } +// Clang tidy complains when using default constructor for IrContainer instead +// of copy constructor. Fusion::copy has a call to IrContainer::copy, so it's +// redundant to use the IrContainer copy constructor, but it is harmless since +// Fusion::copy starts by calling clear(). +Fusion::Fusion(const Fusion& other) : IrContainer(other) { + FUSER_PERF_SCOPE("Fusion copy"); + Fusion::copy(&other, this); +} + Fusion::Fusion(Fusion&& other) noexcept { FUSER_PERF_SCOPE("Fusion move"); swap(*this, other); @@ -147,36 +117,22 @@ Fusion::~Fusion() { void Fusion::clear() noexcept { FUSER_PERF_SCOPE("Fusion clear"); - // Free the owned values - for (auto ptr : val_set_) { - delete ptr; - } - - // Free the owned expressions - for (auto ptr : expr_set_) { - delete ptr; - } - - val_set_.clear(); - val_deque_.clear(); - expr_set_.clear(); - - for (auto& kv : val_type_name_map_) { - kv.second = 0; - } - - expr_name_counter_ = 0; + IrContainer::clear(); inputs_.clear(); outputs_.clear(); io_alias_.clear(); + permuted_input_map_.clear(); permuted_output_map_.clear(); + + all_tv_uses_valid_ = false; + is_during_update_uses_ = false; } void Fusion::removeExpr(Expr* expr) { - assertInFusion(expr, "Cannot remove expr "); + assertInContainer(expr, "Cannot remove expr "); // If we hit this error too frequently, we could lighten the restrictions so // that removing something that doesn't exist simply does nothing. For now, // we're going with the strictest model which errors. @@ -194,13 +150,11 @@ void Fusion::removeExpr(Expr* expr) { } } - expr_set_.erase(expr); - - delete expr; + IrContainer::removeExpr(expr); } void Fusion::removeVal(Val* val) { - assertInFusion(val, "Cannot remove val "); + assertInContainer(val, "Cannot remove val "); TORCH_CHECK( !val->isFusionInput(), @@ -213,22 +167,14 @@ void Fusion::removeVal(Val* val) { if (orig != nullptr) removeExpr(val->definition()); - for (Expr* use : unordered_uses(val)) + for (Expr* use : unordered_uses(val)) { removeExpr(use); - - val_set_.erase(val); - - for (auto it = val_deque_.begin(); it != val_deque_.end(); it++) - if (*it == val) { - val_deque_.erase(it); - break; - } - - delete val; + } + IrContainer::removeVal(val); } void Fusion::addInput(Val* input) { - assertInFusion(input, "Cannot register input "); + assertInContainer(input, "Cannot register input "); if (input->getValType().value() == ValType::TensorView) { auto tv = input->as(); @@ -242,7 +188,7 @@ void Fusion::addInput(Val* input) { } void Fusion::addOutput(Val* output) { - assertInFusion(output, "Cannot register output "); + assertInContainer(output, "Cannot register output "); if (output->getValType().value() == ValType::TensorView) { auto tv = output->as(); tv->setMemoryType(MemoryType::Global); @@ -307,27 +253,8 @@ void Fusion::replaceOutput(Val* output, Val* replacement) { } } -bool Fusion::inFusion(const Statement* stmt) const { - bool in_fusion = stmt->fusion() == this; - Statement* nonconst_stmt = const_cast(stmt); // NOLINT - - if (stmt->isExpr()) { - in_fusion &= expr_set_.find(nonconst_stmt->as()) != expr_set_.end(); - } - if (stmt->isVal()) { - in_fusion &= val_set_.find(nonconst_stmt->as()) != val_set_.end(); - } - - return in_fusion; -} - -void Fusion::assertInFusion(const Statement* stmt, const std::string& msg) - const { - TORCH_CHECK(inFusion(stmt), msg, " it was not found in the active fusion."); -} - std::vector Fusion::exprs() { - return ExprSort::getExprs(this); + return StmtSort::getExprs(this); } std::vector Fusion::inputsOf(Val* val) { @@ -341,12 +268,24 @@ void Fusion::validateInputs() { all_inputs.insert(input); } } + + std::unordered_set input_dims; + auto inp_tvs = ir_utils::filterByType(inputs()); + for (auto tv : inp_tvs) { + for (auto id : tv->getMaybeRFactorDomain()) { + input_dims.emplace(id->extent()); + } + } for (Val* input : all_inputs) { if (!input->isConstScalar()) { TORCH_CHECK( - hasInput(input) || inFusion(input), + input->isFusionInput() || + // TODO: Switch: + inContainer(input), + // to: input_dims.find(input) != input_dims.end(), + // https://github.com/csarofeen/pytorch/issues/1365 "Could not figure out how ", - input, + input->toString(), " is generated, however it was not specified as an input."); } } @@ -367,6 +306,10 @@ void Fusion::print() { void Fusion::printKernel() { FUSER_PERF_SCOPE("Fusion::printKernel"); + TORCH_INTERNAL_ASSERT( + !this->isA(), + "Cannot \"print kernel\" of a kernel container. ", + "This would require lowering during lowering."); std::cout << codegen::generateCudaKernel(GpuLower(this).kernel()); } @@ -394,7 +337,7 @@ void Fusion::printMath(bool from_outputs_only) { leaf_vals.push_back(val); } } - exprs_for_print = ExprSort::getExprs(this, leaf_vals); + exprs_for_print = StmtSort::getExprs(this, leaf_vals); } std::cout << "\n%kernel_math {\n"; @@ -412,33 +355,36 @@ void Fusion::printTransforms() { t_exprs.handle(this); } -StmtNameType Fusion::registerVal(Val* val) { +void Fusion::registerVal(Val* val) { + if (inContainer(val)) { + return; + } + if (val->fusion()) { - if (val->fusion() != this) { - TORCH_CHECK(false, val, " was not found in the active fusion."); - } - if (inFusion(val)) { - return val->name(); - } + TORCH_CHECK( + val->fusion() == this, val, " was not found in the active fusion."); } - val_set_.emplace(val); - val_deque_.push_back(val); - return getValName(*(val->getValType())); + IrContainer::registerVal(val); } -StmtNameType Fusion::registerExpr(Expr* expr) { +void Fusion::registerExpr(Expr* expr) { + if (inContainer(expr)) { + return; + } + if (expr->fusion()) { - if (expr->fusion() != this) { - TORCH_CHECK(false, expr, " was not found in the active fusion."); - } - if (inFusion(expr)) { - return expr->name(); - } + TORCH_CHECK( + expr->fusion() == this, expr, " was not found in the active fusion."); } + IrContainer::registerExpr(expr); + + bool has_tv = false; + for (Val* input : expr->inputs()) { - assertInFusion(input, "Input to expr is invalid, "); + has_tv = has_tv || input->isA(); + assertInContainer(input, "Input to expr is invalid, "); auto uses_copy = input->uses(); if (std::find(uses_copy.begin(), uses_copy.end(), expr) == uses_copy.end()) { @@ -447,34 +393,25 @@ StmtNameType Fusion::registerExpr(Expr* expr) { } } + // Kernel is the only container type that is non-ssa. This is mainly (maybe + // only) because of initialization expressions which would overwrite tensor + // view definitions. + bool is_ssa = !this->isA(); + for (Val* output : expr->outputs()) { - assertInFusion(output, "Output to expr is invalid, "); - if (output->definition() != nullptr) { + has_tv = has_tv || output->isA(); + assertInContainer(output, "Output to expr is invalid, "); + if (output->definition() != nullptr && is_ssa) { removeExpr(output->definition()); } - output->setDefinition(expr); + if (is_ssa || (!is_ssa && output->definition() == nullptr)) { + output->setDefinition(expr); + } } - expr_set_.emplace(expr); - - resetTvUses(); - return getExprName(); -} - -StmtNameType Fusion::registerStatement(Statement* stmt) { - if (inFusion(stmt)) - return stmt->name(); - - if (stmt->isVal()) { - return registerVal(stmt->as()); - } else if (stmt->isExpr()) { - return registerExpr(stmt->as()); + if (has_tv) { + resetTvUses(); } - - TORCH_INTERNAL_ASSERT( - false, - "Could not register statement as Fusion could not recognize its type."); - return kInvalidStmName; } void Fusion::resetTvUses() { @@ -484,8 +421,8 @@ void Fusion::resetTvUses() { // getExprs only uses definition, so even if we've modified uses already to // remove dead exprs, this could reinsert them. getExprs is also boundeds by // inputs as registered inputs will return nullptr as their definition. - const auto all_tvs = ir_utils::filterByType(val_set_); - const auto used_exprs = ExprSort::getExprs(this); + const auto all_tvs = ir_utils::filterByType(vals_); + const auto used_exprs = StmtSort::getExprs(this); for (auto tv : all_tvs) { tv->setUses({}); @@ -507,14 +444,6 @@ void Fusion::resetTvUses() { is_during_update_uses_ = false; } -const std::unordered_set& Fusion::vals() const noexcept { - return val_set_; -} - -const std::deque& Fusion::deterministic_vals() const noexcept { - return val_deque_; -} - std::vector Fusion::usedMathVals() { // Note that using fusion->inputs() as the argument for the first // parameter of getAllValsBetween does not grab all used vals as @@ -553,37 +482,15 @@ std::vector Fusion::usedMathVals() { return used_math_vals; } -const std::unordered_set& Fusion::unordered_exprs() const noexcept { - return expr_set_; -} - std::unordered_set Fusion::unordered_uses(Val* val) const { return std::unordered_set(val->uses().begin(), val->uses().end()); } Expr* Fusion::definition(const Val* val) const { - assertInFusion(val, "Cannot detect the definition of val, "); + assertInContainer(val, "Cannot detect the definition of val, "); return val->definition(); } -bool Fusion::hasInput(const Val* val) const { - assertInFusion(val, "Cannot check if val is an input, "); - return val->isFusionInput(); -} - -bool Fusion::hasOutput(const Val* val) const { - assertInFusion(val, "Cannot check if val is an output, "); - return val->isFusionOutput(); -} - -StmtNameType Fusion::getValName(ValType vtype) { - return val_type_name_map_[vtype]++; -} - -StmtNameType Fusion::getExprName() { - return expr_name_counter_++; -} - // Indicate to kernel to set itself up to generate random numbers bool Fusion::isStochastic() { for (auto expr : exprs()) @@ -593,28 +500,6 @@ bool Fusion::isStochastic() { return false; } -bool Fusion::hasReduction() { - FUSER_PERF_SCOPE("Fusion::hasReduction"); - - for (auto expr : exprs()) - for (auto out : expr->outputs()) - if (out->getValType() == ValType::TensorView) - if (out->as()->hasReduction()) - return true; - - return false; -} - -bool Fusion::hasWelford() { - FUSER_PERF_SCOPE("Fusion::hasWelford"); - for (auto expr : exprs()) { - if (expr->isA()) { - return true; - } - } - return false; -} - std::vector Fusion::getTerminatingOutputs() { FUSER_PERF_SCOPE("getTerminatingOutputs"); diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index c892bd8171c..2e76e00896b 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -1,10 +1,11 @@ #pragma once #include +#include #include -#include #include +#include #include #include @@ -69,14 +70,14 @@ class TORCH_CUDA_CU_API FusionGuard { //! Fusion is mutable but unique. Nodes cannot be copied in any way from one //! Fusion to another. If anything like that is desired, it would require -//! duplicating all associated values and exprs. Fusion is considered to SSA, +//! duplicating all associated values and exprs. Fusion is considered to be SSA, //! though this could also change in the future if there is a good reason to do //! so. //! //! The Fusion owns the whole IR graph (Vals and Exprs) //! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API Fusion final { +class TORCH_CUDA_CU_API Fusion : public IrContainer { typedef std::unordered_map> PermutationMap; public: @@ -96,45 +97,30 @@ class TORCH_CUDA_CU_API Fusion final { //! Break dependency chains associated with Expr, remove references to expr //! delete expr - void removeExpr(Expr* expr); + void removeExpr(Expr* expr) override; //! Completely remove val from the fusion, break all dependencies associated //! with it - void removeVal(Val* val); + void removeVal(Val* val) override; //! Register input as an input of the fusion - // TODO: Rename to register void addInput(Val* input); //! Register output as an output of the fusion - // TODO: Rename to register void addOutput(Val* output); //! Register output as an output of the fusion - // TODO: Rename to register void addOutput(WelfordResult& output); //! Deregister input as an input of the fusion - // TODO: Rename to register void removeInput(Val* input); //! Deregister output as an output of the fusion - // TODO: Rename to register void removeOutput(Val* output); //! Replace output with another value void replaceOutput(Val* output, Val* replacement); - //! Clear Expr's from TV uses that are not required to produce outputs from - //! inputs - void resetTvUses(); - - //! Check if stmt is properly registered with this fusion - bool inFusion(const Statement* stmt) const; - - //! Throw an error if stmt is not in this fusion - void assertInFusion(const Statement* stmt, const std::string& msg = "") const; - //! Assert that all leaves found from outputs are registered as an input void validateInputs(); @@ -151,17 +137,6 @@ class TORCH_CUDA_CU_API Fusion final { //! Lower the fusion and print a kernel void printKernel(); - //! Register the Val with this fusion - StmtNameType registerVal(Val* val); - - //! Register expr with this fusion. - //! When we register an expression, we want to update the dependency tracking - //! of Vals. We add expr to our general expr_set_, - StmtNameType registerExpr(Expr* expr); - - //! Register stmt with this fusion - StmtNameType registerStatement(Statement* stmt); - //! Return a list of topologically sorted expressions. This only includes //! exprs required to genereate registered outputs. std::vector exprs(); @@ -169,12 +144,6 @@ class TORCH_CUDA_CU_API Fusion final { //! Return a vector of fusion inputs that feed this Val std::vector inputsOf(Val* val); - //! Return the set of Vals registered with this fusion - const std::unordered_set& vals() const noexcept; - - //! Return in insertion order - const std::deque& deterministic_vals() const noexcept; - //! Return all Vals in math expressions that cannot be eliminated. //! //! It is generally equivalent to vals that are used to generate @@ -183,11 +152,6 @@ class TORCH_CUDA_CU_API Fusion final { //! also included as they must show up in the final code. std::vector usedMathVals(); - //! Return the set of Exprs registered with this fusion. Warning: This will - //! return exprs outside inputs/outputs, so can be unsafe for use with - //! segmented fusions. - const std::unordered_set& unordered_exprs() const noexcept; - //! Return all Exprs that use val std::unordered_set unordered_uses(Val* val) const; @@ -197,12 +161,6 @@ class TORCH_CUDA_CU_API Fusion final { //! Indicate to kernel to set itself up to generate random numbers bool isStochastic(); - //! Indicate that the fusion contains reduction operations - bool hasReduction(); - - //! Indicate that the fusion contains welford operations - bool hasWelford(); - //! Run fusion segmentation algorithm to create a segmented fusion std::unique_ptr segment( const at::ArrayRef& inputs); @@ -217,9 +175,6 @@ class TORCH_CUDA_CU_API Fusion final { std::vector getTerminatingOutputs(); - bool hasInput(const Val* val) const; - bool hasOutput(const Val* val) const; - // Aliasing output to input value, this is a WAR to allow inplace update on // input tensor. // Note: this is not always safe and should be used with extra caution. @@ -262,36 +217,40 @@ class TORCH_CUDA_CU_API Fusion final { return is_during_update_uses_; } + const auto& ioAlias() const { + return io_alias_; + } + protected: friend SegmentCandidateFinder; friend SegmentedFusion; friend class TranslateApplicableWelford; + friend Val; static IrCloner copy(const Fusion* from, Fusion* to); - private: - // Return an int that monotonically increases for each val/expr, some are - // explicitly incremented by type. - StmtNameType getValName(ValType vtype); - StmtNameType getExprName(); + //! Register the Val with this fusion + virtual void registerVal(Val* val) override; + + //! Register expr with this fusion. + //! When we register an expression, we want to update the dependency tracking + //! of Vals. If this container is a not a Kernel, it will remove previous + //! definitions of outputs and register this Expr as the definition. Otherwise + //! will update definition if not previously set, but will not remove old + //! definitions. + virtual void registerExpr(Expr* expr) override; + //! Clear Expr's from TV uses that are not required to produce outputs from + //! inputs. Only other place this is used (other than Fusion) is in + //! Val::uses() + void resetTvUses(); + + private: // Determine if the two values are compatible for aliasing // Same DataType, ValType, and number of dimensions bool isAliasCompatible(Val* left, Val* right); private: - // Sets of all Vals/Exprs registered with this fusion - // (val_deque_ is not owning the objects) - std::unordered_set val_set_; - std::deque val_deque_; - std::unordered_set expr_set_; - - // Values names counters - std::unordered_map val_type_name_map_; - - // Expression names counter - StmtNameType expr_name_counter_ = 0; - // Fusion inputs and outputs std::vector inputs_; std::vector outputs_; diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp index 9ff25780814..fd7b6fc502a 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp @@ -322,7 +322,7 @@ void SegmentedFusion::draw() { for (auto group : groups()) { for (auto expr : group->exprs()) { - if (ir_utils::isTVOp(expr)) { + if (ir_utils::isTvOp(expr)) { expr_color_map[expr] = group_index; } } @@ -659,8 +659,8 @@ TensorView* castIntermediateValueInCompleteFusion( } // Create the actual domain and tv. - return new TensorView( - new TensorDomain( + return IrBuilder::create( + IrBuilder::create( new_root_domain, std::vector(new_root_domain.size(), true)), data_type); }; @@ -680,8 +680,8 @@ TensorView* castIntermediateValueInCompleteFusion( } // Insert the cast ops. - new UnaryOp(UnaryOpType::Cast, half_precision_tv, original_tv); - new UnaryOp(UnaryOpType::Cast, fp32_tv, half_precision_tv); + IrBuilder::create(UnaryOpType::Cast, half_precision_tv, original_tv); + IrBuilder::create(UnaryOpType::Cast, fp32_tv, half_precision_tv); // Return the new tv to replace original tv with // on the segmented edges. @@ -1740,9 +1740,10 @@ TranslateApplicableWelford::TranslateApplicableWelford( Fusion* fusion, const at::ArrayRef& runtime_inputs) : runtime_inputs_(runtime_inputs) { + auto exprs = fusion->exprs(); std::vector orignal_welfords( - ir_utils::filterByType(fusion->unordered_exprs()).begin(), - ir_utils::filterByType(fusion->unordered_exprs()).end()); + ir_utils::filterByType(exprs).begin(), + ir_utils::filterByType(exprs).end()); if (wouldTranslateToPersistent(orignal_welfords)) { for (auto welford : orignal_welfords) { @@ -1829,6 +1830,14 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( [&original_to_test_map](auto welford) { return original_to_test_map.clone(welford); }); + // Copied welfords will be invalidated on translation, but Vals will be + // reused, keep a reference to them. + std::vector welford_avgs; + std::vector welford_vars; + for (auto welford : copied_welfords) { + welford_avgs.push_back(welford->outAvg()); + welford_vars.push_back(welford->outVar()); + } // Translate the welford ops for (auto welford_to_translate : copied_welfords) { @@ -1860,6 +1869,21 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( return original_to_test_map.clone(out); }); + // If only average is used from welford, we should still translate, but we + // might not detect persistence if variance isn't actually used/marked as an + // output in the test. + for (auto outs_i : c10::irange(welford_avgs.size())) { + auto avg = welford_avgs[outs_i]; + auto var = welford_vars[outs_i]; + if (avg->uses().empty()) { + test_group_outputs_.push_back(avg); + } + + if (var->uses().empty()) { + test_group_outputs_.push_back(var); + } + } + // Temporarily localize test copy around // the group boundary FusionSegmentGuard fsg( @@ -1900,7 +1924,7 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { // Create scalar version of the feature element // counting. - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(1); std::vector broadcast_mask(in_root.size(), false); for (const auto i : c10::irange(in_root.size())) { if (out_root[i]->isReduction()) { @@ -1913,7 +1937,7 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { // Build a normalization expression group that is // equivalent to a welford operation. auto x_sum = sum(in_val, red_axes); - new BinaryOp(BinaryOpType::Div, out_avg, x_sum, num_features); + IrBuilder::create(BinaryOpType::Div, out_avg, x_sum, num_features); // welford.avg may be broadcast. Reuse it if found. TensorView* x_avg_bcast = nullptr; for (auto& use_expr : out_avg->uses()) { @@ -1949,8 +1973,12 @@ void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) { } auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); - new ReductionOp(BinaryOpType::Add, new Double(0.0), out_var, x_mean_sub_pow); - new UnaryOp(UnaryOpType::Set, out_N, num_features); + IrBuilder::create( + BinaryOpType::Add, + IrBuilder::create(0.0), + out_var, + x_mean_sub_pow); + IrBuilder::create(UnaryOpType::Set, out_N, num_features); // out_avg, out_N are now outputs of a pointwise ops and we // need to clear out its reduction domains. @@ -2687,14 +2715,20 @@ void SegmentCandidateFinder::findSegments() { } } + auto reduction_ops = + ir_utils::getReductionOps(segmented_fusion_->completeFusion()); + auto welford_ops = ir_utils::filterByType(reduction_ops); + if (options_.run_translate_welford && - segmented_fusion_->completeFusion()->hasWelford()) { + (welford_ops.begin() != welford_ops.end())) { TranslateApplicableWelford::run(segmented_fusion_.get(), runtime_inputs_); } for (auto group : groups()) { - // Set heuristics in case single reduction kernels were left out - group->setHeuristic(deriveHeuristic(group)); + if (!group->outputs().empty()) { + // Set heuristics in case single reduction kernels were left out + group->setHeuristic(deriveHeuristic(group)); + } } // Remove all scalar edges since they do not represent actual @@ -2913,7 +2947,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) { group->input_vals = IterVisitor::getInputsTo(group->inputs()); // Grab all expressions needed to produce to_visit - auto input_exprs = ExprSort::getExprs(completeFusion(), to_visit); + auto input_exprs = StmtSort::getExprs(completeFusion(), to_visit); // Insert those expressions at the beginning of the group group->exprs_.insert( @@ -3102,7 +3136,7 @@ void SegmentedFusion::annotateFP16IntermediateTensors() { } } -TORCH_CUDA_CU_API std::string toString( +std::string toString( const SegmentCandidateFinderOptions& segment_options) { std::stringstream ss; ss << "segmentation phases {\n"; diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h index 61fa966348e..63124839fc1 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h @@ -288,11 +288,11 @@ class TORCH_CUDA_CU_API SegmentedFusion { } Val* findAlias(Val* val) const { - Val* alias_val = nullptr; - if (complete_fusion_->io_alias_.count(val) != 0) { - alias_val = complete_fusion_->io_alias_[val]; + auto alias_it = complete_fusion_->ioAlias().find(val); + if (alias_it != complete_fusion_->ioAlias().end()) { + return alias_it->second; } - return alias_val; + return nullptr; } //! Make a clone of the group and convert to fusion diff --git a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp index 47c0316abda..b2d1f893ba6 100644 --- a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include #include @@ -19,6 +21,7 @@ #include #include +#include #include #include @@ -46,6 +49,13 @@ bool usedOnlyInDtype(Value* v) { Value* broadcastSizes(at::ArrayRef sizes) { AT_ASSERT(!sizes.empty()); Graph* graph = sizes[0]->owningGraph(); + Node* insertion_point = sizes[0]->node()->next(); + for (size_t i = 1; i < sizes.size(); i++) { + if (insertion_point->isBefore(sizes[i]->node()->next())) { + insertion_point = sizes[i]->node()->next(); + } + } + WithInsertPoint guard(insertion_point); Node* broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); broadcast_n->output()->setType(ListType::ofInts()); @@ -66,9 +76,13 @@ Value* createConditionalConstant(Node* profile_ivalue) { auto int_list = profile_ivalue->is(Symbol::attr("profiled_bool_list")); std::vector bool_list(int_list.begin(), int_list.end()); val = IValue(bool_list); - } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_size"))) { + } else if (profile_ivalue->hasAttribute( + Symbol::attr("profiled_reduction_size"))) { // int[] - val = IValue(profile_ivalue->is(Symbol::attr("profiled_size"))); + val = IValue(profile_ivalue->is(Symbol::attr("profiled_reduction_size"))); + } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_view_size"))) { + // int[] + val = IValue(profile_ivalue->is(Symbol::attr("profiled_view_size"))); } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool"))) { // bool val = IValue( @@ -101,6 +115,7 @@ struct CudaGraphFuser { std::unique_ptr aliasDb_; std::shared_ptr graph_; Symbol kind_ = prim::CudaFusionGroup; + std::unordered_map fusion_value_to_runtime_shape_; // nvrtc has a limit on the number of arguments allowed in a CUDA kernel. // The specific limit is a function of constant memory size, amount available @@ -764,9 +779,11 @@ struct CudaGraphFuser { // longer valid so we rescan the new FusionGroup for more fusions... return std::make_pair(fusion_group.value()->reverseIterator(), true); } - // horizontal fusion only applies on tensor inputs + + // horizontal fusion only applies on non-scalar tensor inputs if (getHorizontalFusion() && - producer->type()->isSubtypeOf(*TensorType::get())) { + producer->type()->isSubtypeOf(*TensorType::get()) && + !is_cpu_scalar(*producer->type()->cast())) { // fusing nodes sharing inputs, this could save memory bandwidth by // reducing number of tensor read. for (const auto& u : producer->uses()) { @@ -838,6 +855,7 @@ struct CudaGraphFuser { // Builds up expressions that compute shapes of all intermediates (and // outputs) of the fusion group, based on the sizes of inputs. You should run // DCE to remove those that you end up not using. + // TODO: Add shape support for view, reshape, unsqueeze, and squeeze std::unordered_map buildShapeExpressions(Node* fusion_group) { WithInsertPoint insert_guard{fusion_group->next()}; std::unordered_map shape_of; @@ -850,7 +868,9 @@ struct CudaGraphFuser { AT_ASSERT(inputs.size() == sinputs.size()); for (const auto i : c10::irange(inputs.size())) { if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) { - shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]}); + auto sinput_value = graph->insert(aten::size, {inputs[i]}); + shape_of[sinputs[i]] = sinput_value; + sinput_value->node()->moveBefore(fusion_group); } } @@ -869,6 +889,26 @@ struct CudaGraphFuser { } } + // Place all the shape expressions for intermediates in fusion + // before the CudaFusionGroup + graph->setInsertPoint(fusion_group); + + // hmmm, do I need to setInsertPoint... + const auto map_inputs = [&](Value* v) -> Value* { + // if constant ever has an input, it has to come from + // profile_ivalue dependency + if (v->node()->kind() == prim::Param && + fusion_group->input(v->offset())->node()->kind() == + prim::profile_ivalue) { + // we need to map it along profile_ivalue dependency + return fusion_group->input(v->offset()); + } else { + throw std::runtime_error( + std::string("unexpected input from node") + + v->node()->kind().toDisplayString()); + } + }; + for (Node* n : subgraph->nodes()) { // XXX: Use of shape_of.emplace is crucial to the output shape // optimization! @@ -912,21 +952,6 @@ struct CudaGraphFuser { n->input(2)->node()->kind() == prim::Constant, "only supports reduction axes and keepdim being constant"); - // hmmm, do I need to setInsertPoint... - const auto map_inputs = [&](Value* v) -> Value* { - // if constant ever has an input, it has to come from - // profile_ivalue dependency - if (v->node()->kind() == prim::Param && - fusion_group->input(v->offset())->node()->kind() == - prim::profile_ivalue) { - // we need to map it along profile_ivalue dependency - return fusion_group->input(v->offset()); - } else { - throw std::runtime_error( - std::string("unexpected input from node") + - v->node()->kind().toDisplayString()); - } - }; Node* in1_const = graph->createClone(n->input(1)->node(), map_inputs); graph->insertNode(in1_const); Node* in2_const = graph->createClone(n->input(2)->node(), map_inputs); @@ -1000,6 +1025,57 @@ struct CudaGraphFuser { } continue; } + if (n->kind() == aten::native_dropout) { + TORCH_INTERNAL_ASSERT( + shape_of.count(n->input(0)) > 0, + "buildShapeExpressions failed at accessing input shapes"); + shape_of.emplace(n->output(0), shape_of.at(n->input(0))); + shape_of.emplace(n->output(1), shape_of.at(n->input(0))); + continue; + } + if (n->kind() == prim::unsqueeze_copy) { + TORCH_INTERNAL_ASSERT( + shape_of.count(n->input(0)) > 0, + "buildShapeExpressions failed at accessing input shapes"); + TORCH_INTERNAL_ASSERT( + n->input(1)->node()->kind() == prim::Constant, + "only supports unsqueeze axes being constant"); + Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs); + graph->insertNode(dim_const); + std::vector inputs = { + shape_of.at(n->input(0)), dim_const->output()}; + Node* size_node = graph->insertNode(graph->create( + Symbol::fromQualString("prim::infer_unsqueeze_size"), inputs, 1)); + Value* size = size_node->output(0); + size->setType(ListType::ofInts()); + shape_of.emplace(n->output(), size); + continue; + } + if (n->kind() == prim::squeeze_copy) { + TORCH_INTERNAL_ASSERT( + shape_of.count(n->input(0)) > 0, + "buildShapeExpressions failed at accessing input shapes"); + TORCH_INTERNAL_ASSERT( + n->inputs().size() == 2 || n->inputs().size() == 1, + "prim::squeeze_copy expects one or two inputs"); + std::vector inputs = {shape_of.at(n->input(0))}; + + if (n->inputs().size() == 2) { + TORCH_INTERNAL_ASSERT( + n->input(1)->node()->kind() == prim::Constant, + "only supports squeeze axes being constant"); + Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs); + graph->insertNode(dim_const); + inputs.push_back(dim_const->output()); + } + Node* size_node = graph->insertNode(graph->create( + Symbol::fromQualString("prim::infer_squeeze_size"), inputs, 1)); + Value* size = size_node->output(0); + size->setType(ListType::ofInts()); + shape_of.emplace(n->output(), size); + continue; + } + auto tensor_inputs = filter(n->inputs(), [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); }); @@ -1025,8 +1101,9 @@ struct CudaGraphFuser { // TODO: failure in buildShapeExpressions should not break fusion execution, // we can add a try/catch here to bailout from removeOutputsUsedOnlyInSize. GRAPH_DEBUG("before build shape expression: ", *graph_); - auto shape_of = buildShapeExpressions(fusion_group); + fusion_value_to_runtime_shape_ = buildShapeExpressions(fusion_group); GRAPH_DEBUG("after build shape expression: ", *graph_); + auto outputs = fusion_group->outputs().vec(); auto soutputs = subgraph->outputs().vec(); // XXX: Iterating in this order is not only good for performance reasons! @@ -1035,12 +1112,14 @@ struct CudaGraphFuser { for (int64_t i = static_cast(outputs.size()) - 1; i >= 0; --i) { auto output = outputs[i]; auto soutput = soutputs[i]; - if (usedOnlyInDtypeAndSize(output) && shape_of.count(soutput) > 0) { + if (usedOnlyInDtypeAndSize(output) && + fusion_value_to_runtime_shape_.count(soutput) > 0) { bool has_dtype = usedInDtype(output); auto uses = output->uses(); for (Use u : uses) { if (u.user->matches("aten::size(Tensor self) -> int[]")) { - u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); + u.user->output()->replaceAllUsesWith( + fusion_value_to_runtime_shape_.at(soutput)); u.user->destroy(); } else if (u.user->matches("prim::dtype(Tensor a) -> int")) { continue; @@ -1286,6 +1365,55 @@ void PeepholeOptimizeShapeExpressions(Block* block) { } } +// view_sizes_runtime is the profiled-ivalue argument for view-size. +// view_sizes_constant_list is the constant list recorded during profiling runs. +Value* guardView( + Node* fusion, + std::unordered_map& fusion_value_to_runtime_size, + Node* versioning_if, + Node* view, + Value* view_sizes_runtime) { + // 1. Get self tensor sizes and view_sizes + auto self_value = view->inputs().front(); + auto self_type = self_value->type()->cast(); + auto self_sizes_constant_list = getTensorSizes(self_type); + + auto view_sizes_constant_list = + constant_as>(view->inputs().back()); + TORCH_INTERNAL_ASSERT(view_sizes_constant_list.has_value()); + + // 2. Get constraints for self tensor and view_sizes + auto constraints = analyzeViewConstraint( + self_sizes_constant_list, view_sizes_constant_list->vec()); + + // 3. Add constraints as constant to graph + auto self_tensor_constraint = fusion->owningGraph()->insertConstant( + IValue(constraints.original_constraint)); + self_tensor_constraint->node()->moveBefore(versioning_if); + auto view_sizes_constraint = + fusion->owningGraph()->insertConstant(IValue(constraints.new_constraint)); + view_sizes_constraint->node()->moveBefore(versioning_if); + + // 4. Create CudaFusionViewGuard using input tensor, profile_ivalue + // for view_sizes list, and constraints + TORCH_INTERNAL_ASSERT( + fusion_value_to_runtime_size.find(self_value) != + fusion_value_to_runtime_size.end(), + "Failed to find runtime size for fusion value:\t", + self_value->node()->kind().toDisplayString()); + Node* viewcheck_node = + fusion->owningGraph() + ->create( + c10::Symbol::fromQualString("prim::CudaFusionViewGuard"), + {fusion_value_to_runtime_size.at(self_value), + view_sizes_runtime, + self_tensor_constraint, + view_sizes_constraint}, + 1) + ->insertBefore(versioning_if); + return viewcheck_node->output(); +} + //! [ Note -- CudaFusionGuard implementation ] //! //! shamelessly copying code from NNC (tensorexpr_fuser) with very little @@ -1324,7 +1452,9 @@ void PeepholeOptimizeShapeExpressions(Block* block) { //! //! TODO: we also need to assert/check reduction axes and replace it with //! constants in `CudaFusionGroup` -void guardFusionGroup(Node* fusion) { +void guardFusionGroup( + Node* fusion, + std::unordered_map& fusion_value_to_runtime_size) { // Fixup types of the subgraph inputs std::vector guard_types; std::vector tensor_inputs_to_check; @@ -1375,10 +1505,12 @@ void guardFusionGroup(Node* fusion) { versioning_if->insertAfter(typecheck_node); + auto fusion_graph = fusion->g(attr::Subgraph); + std::vector check_flags = {}; + // Fill in the false block. It should contain the unoptimized // copy of the fused subgraph, unless we have conditional constants from // profiled_ivalue; - auto fusion_graph = fusion->g(attr::Subgraph); std::shared_ptr fb_graph; // resource holder; // Restore the dependency for constant introduced by profiled_ivalue within // the graph. @@ -1425,11 +1557,10 @@ void guardFusionGroup(Node* fusion) { // 2. REMOVE conditional constant dependency in fusion group size_t compensation = 0; - // get a constant false, which is used by `and` pattern later + // get a constant true, which is used by `and` pattern later auto const_true = fusion->owningGraph()->insertConstant(IValue(true)); const_true->node()->moveBefore(versioning_if); - std::vector check_flags = {}; for (const auto& original_offset : profiled_ivalue_indices) { size_t offset = original_offset - compensation; @@ -1457,7 +1588,7 @@ void guardFusionGroup(Node* fusion) { ->insertBefore(versioning_if) ->output(); } else if (fusion->input(offset)->node()->hasAttribute( - Symbol::attr("profiled_size"))) { + Symbol::attr("profiled_reduction_size"))) { // TODO(profile_size): check sizes here with special size comparison op // TORCH_INTERNAL_ASSERT(false, "not implemented yet"); ivalue_check = @@ -1468,6 +1599,28 @@ void guardFusionGroup(Node* fusion) { 1) ->insertBefore(versioning_if) ->output(); + } else if (fusion->input(offset)->node()->hasAttribute( + Symbol::attr("profiled_view_size"))) { + // TODO: Add support for dynamic split to view guard + + // Path from profile-ivalue to prim::view_copy operation + // profile-ivalue -> Uses: [Constant, CudaFusionGroup] + // Get argument position in CudaFusionGroup + // Get argument in subgraph for CudaFusionGroup + // CudaFusionGroup argument -> Constant List -> prim::view_copy + auto cuda_fusion_group_arg = profiled_ival->uses().back().offset; + auto subgraph_arg = fusion_graph->inputs()[cuda_fusion_group_arg]; + auto constant = subgraph_arg->uses().front().user->output(); + auto view = constant->uses().front().user; + TORCH_INTERNAL_ASSERT( + view->kind() == prim::view_copy || + view->kind() == prim::reshape_copy); + ivalue_check = guardView( + fusion, + fusion_value_to_runtime_size, + versioning_if, + view, + profiled_ival); } else { ivalue_check = fusion->owningGraph() ->create(aten::eq, {profiled_ival, const_o}, 1) @@ -1495,22 +1648,24 @@ void guardFusionGroup(Node* fusion) { fusion_graph->eraseInput(offset); compensation++; } - - if (!check_flags.empty()) { - // attaching output from CudaFusionGuard to profile ivalue checks - check_flags.emplace_back(typecheck_result); - auto graph = fusion->owningGraph(); - auto bool_list_node = - graph->insertNode(graph->createList(BoolType::get(), check_flags)); - bool_list_node->moveBefore(versioning_if); - Value* bool_list = bool_list_node->output(); - // new typecheck_result - typecheck_result = graph->insert(aten::all, {bool_list}); - typecheck_result->node()->moveBefore(versioning_if); - } // update graph in fusion node fusion->g_(attr::Subgraph, fusion_graph); - } else { + } + + if (!check_flags.empty()) { + // attaching output from CudaFusionGuard to profile ivalue checks + check_flags.emplace_back(typecheck_result); + auto graph = fusion->owningGraph(); + auto bool_list_node = + graph->insertNode(graph->createList(BoolType::get(), check_flags)); + bool_list_node->moveBefore(versioning_if); + Value* bool_list = bool_list_node->output(); + // new typecheck_result + typecheck_result = graph->insert(aten::all, {bool_list}); + typecheck_result->node()->moveBefore(versioning_if); + } + + if (profiled_ivalue_indices.empty()) { WithInsertPoint guard(false_block->return_node()); const auto subgraph_outputs = insertGraph(*fusion->owningGraph(), *fusion_graph, fusion->inputs()); @@ -1536,11 +1691,13 @@ void guardFusionGroup(Node* fusion) { } } -void guardFusionGroups(Block* block) { +void guardFusionGroups( + Block* block, + std::unordered_map& fusion_value_to_runtime_size) { std::vector fusions; for (Node* n : block->nodes()) { for (Block* b : n->blocks()) { - guardFusionGroups(b); + guardFusionGroups(b, fusion_value_to_runtime_size); } if (n->kind() == prim::CudaFusionGroup) { fusions.push_back(n); @@ -1550,7 +1707,7 @@ void guardFusionGroups(Block* block) { // step 1: a. add prim::CudaFusionGuard and fallback logic // b. insert guard logic of profile_ivalue with if block // c. restore conditional constant to non-constant for fallback - guardFusionGroup(fusion); + guardFusionGroup(fusion, fusion_value_to_runtime_size); } } @@ -1918,6 +2075,85 @@ void decomposeLinearOps(Block* block) { } } +// Replace 'operation' with 'operation_copy' to guard alias operations. +// Supports View, Reshape, Squeeze, and Unsqueeze +void replaceAliasOpsWithCopy(std::shared_ptr& graph, Block* block) { + static std::unordered_map op_mapping( + {{aten::view, prim::view_copy}, + {aten::reshape, prim::reshape_copy}, + {aten::squeeze, prim::squeeze_copy}, + {aten::unsqueeze, prim::unsqueeze_copy}}); + + std::vector maybe_alias_nodes; + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + replaceAliasOpsWithCopy(graph, b); + } + if (op_mapping.find(n->kind()) != op_mapping.end()) { + maybe_alias_nodes.push_back(n); + } + } + + auto alias_db = std::make_unique(graph); + for (Node* n : maybe_alias_nodes) { + if (!alias_db->safeToChangeAliasingRelationship( + n->input(0), n->output(0))) { + continue; + } + + WithInsertPoint guard(n); + auto op_copy = + graph->insertNode(graph->create(op_mapping[n->kind()], n->inputs(), 1)); + op_copy->output()->setType(n->output(0)->type()); + + // adding newly created value into alias_db; + alias_db->createValue(op_copy->output()); + + n->output()->replaceAllUsesWith(op_copy->output()); + n->destroy(); + } +} + +// Revert all 'op_copy' with 'op' except in CudaFusionGroup +// e.g., Any non-fused alias operation including within the prim::FallbackGraph +// Supports View, Reshape, Squeeze, and Unsqueeze +void revertAliasCopyOps(std::shared_ptr& graph, Block* block) { + static std::unordered_map op_mapping( + {{prim::view_copy, aten::view}, + {prim::reshape_copy, aten::reshape}, + {prim::squeeze_copy, aten::squeeze}, + {prim::unsqueeze_copy, aten::unsqueeze}}); + + std::vector alias_copy_ops; + for (Node* n : block->nodes()) { + // Allow alias copy ops in CudaFusionGroup + if (n->kind() == prim::CudaFusionGroup) { + continue; + } + // Revert alias copy ops within FallbackGraph + if (n->kind() == prim::FallbackGraph) { + auto subgraph = n->g(attr::Subgraph); + revertAliasCopyOps(subgraph, subgraph->block()); + } + for (Block* b : n->blocks()) { + revertAliasCopyOps(graph, b); + } + // Revert any non-fused alias copy ops + if (op_mapping.find(n->kind()) != op_mapping.end()) { + alias_copy_ops.push_back(n); + } + } + + for (Node* n : alias_copy_ops) { + WithInsertPoint guard(n); + auto reverted_op = + graph->insertNode(graph->create(op_mapping[n->kind()], n->inputs(), 1)); + reverted_op->output()->setType(n->output(0)->type()); + n->output()->replaceAllUsesWith(reverted_op->output()); + n->destroy(); + } +} + // break `conv2d` layer into `conv2d` and `add_optional`. This allows us to fuse // the binary operation without supporting gemm. // Note that we are not breaking `conv2d` layer without bias. @@ -2030,12 +2266,16 @@ void CudaFuseGraph(std::shared_ptr& graph) { decomposeConvOps(graph->block()); GRAPH_DEBUG("After decompose decompose Conv Ops by nvfuser: ", *graph); - CudaGraphFuser(graph->block(), graph).run(); + replaceAliasOpsWithCopy(graph, graph->block()); + GRAPH_DEBUG("replace alias_op with alias_copy by nvfuser: ", *graph); + + CudaGraphFuser cgf(graph->block(), graph); + cgf.run(); GRAPH_DEBUG("After Fusion: ", *graph); // guard input types as well as conditional constants from // aten::profile_ivalue - guardFusionGroups(graph->block()); + guardFusionGroups(graph->block(), cgf.fusion_value_to_runtime_shape_); GRAPH_DEBUG("After Guard Fusion: ", *graph); // mutate `aten::_batch_norm_impl_index` and @@ -2053,6 +2293,10 @@ void CudaFuseGraph(std::shared_ptr& graph) { // optimization targeting AMP removeOutputUsedOnlyInDtype(graph->block()); GRAPH_DEBUG("After removeOutputUsedOnlyInDtype: ", *graph); + + revertAliasCopyOps(graph, graph->block()); + GRAPH_DEBUG("revert alias_copy ops by nvfuser: ", *graph); + // After FuseGraph some common subexpressions may come back EliminateCommonSubexpression(graph); // We might have emitted a fair amount of useless shape propagating code, so diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 39176a60c53..8e151372b75 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -10,9 +10,8 @@ #include #include #include -#include -#include #include +#include #include #include #include @@ -44,9 +43,9 @@ class ContigIDs : public OptInDispatch { using OptInDispatch::handle; // Mark if ids are result of contigous merges - std::unordered_set contig_ids; + std::unordered_set contig_ids; // Given contiguous domain, return all iter domains within its history. - std::unordered_map> + std::unordered_map> within_contig_ids; const std::vector& root_domain_; const std::vector& root_contiguity_; @@ -58,7 +57,7 @@ class ContigIDs : public OptInDispatch { }); } - bool isContig(kir::IterDomain* id) { + bool isContig(IterDomain* id) { return contig_ids.find(id) != contig_ids.end(); } @@ -66,14 +65,11 @@ class ContigIDs : public OptInDispatch { void handle(Split*) override {} void handle(Merge* merge) override { - const auto gpu_lower = GpuLower::current(); - // If either input is non-contiguous so is output. const auto inner = merge->inner(); const auto outer = merge->outer(); - if ((!isContig(gpu_lower->lowerValue(inner)->as()) || - !isContig(gpu_lower->lowerValue(outer)->as()))) { + if (!isContig(inner) || !isContig(outer)) { return; } @@ -136,38 +132,34 @@ class ContigIDs : public OptInDispatch { // If we matched all inputs, the output is contiguous. Only want to keep the // top contig ID, lower ids should be placed in the "within_contig_ids" map // of top id. - auto kir_inner = - gpu_lower->lowerValue(merge->inner())->as(); - auto kir_outer = - gpu_lower->lowerValue(merge->outer())->as(); - auto kir_out = gpu_lower->lowerValue(merge->out())->as(); + auto out = merge->out()->as(); if (ordered_inputs.empty()) { - if (contig_ids.find(kir_inner) != contig_ids.end()) { - contig_ids.erase(kir_inner); + if (contig_ids.find(inner) != contig_ids.end()) { + contig_ids.erase(inner); } - if (contig_ids.find(kir_outer) != contig_ids.end()) { - contig_ids.erase(kir_outer); + if (contig_ids.find(outer) != contig_ids.end()) { + contig_ids.erase(outer); } - contig_ids.emplace(kir_out); + contig_ids.emplace(out); - std::unordered_set within_out; - within_out.emplace(kir_inner); - if (within_contig_ids.find(kir_inner) != within_contig_ids.end()) { - auto in_inner = within_contig_ids.at(kir_inner); + std::unordered_set within_out; + within_out.emplace(inner); + if (within_contig_ids.find(inner) != within_contig_ids.end()) { + auto in_inner = within_contig_ids.at(inner); within_out.insert(in_inner.begin(), in_inner.end()); - within_contig_ids.erase(kir_inner); + within_contig_ids.erase(inner); } - within_out.emplace(kir_outer); - if (within_contig_ids.find(kir_outer) != within_contig_ids.end()) { - auto in_outer = within_contig_ids.at(kir_outer); + within_out.emplace(outer); + if (within_contig_ids.find(outer) != within_contig_ids.end()) { + auto in_outer = within_contig_ids.at(outer); within_out.insert(in_outer.begin(), in_outer.end()); - within_contig_ids.erase(kir_outer); + within_contig_ids.erase(outer); } - within_contig_ids[kir_out] = within_out; + within_contig_ids[out] = within_out; } } @@ -195,8 +187,6 @@ class ContigIDs : public OptInDispatch { " != ", root_contiguity_.size()); - const auto gpu_lower = GpuLower::current(); - for (const auto i : c10::irange(root_domain_.size())) { // If a root domain has halo, can't use merged domain even if // both inputs are contiguous. HaloInfo is also initialized for @@ -204,32 +194,32 @@ class ContigIDs : public OptInDispatch { // RootAxisInfo. This should be safe as no rfactor tensor should // need halo. if (root_contiguity_[i] && - !gpu_lower->haloInfo().getRootAxisInfo(root_domain_[i]).hasHalo()) { - auto kir_root_domain_i = - gpu_lower->lowerValue(root_domain_[i])->as(); - contig_ids.emplace(kir_root_domain_i); - within_contig_ids[kir_root_domain_i] = - std::unordered_set(); + !GpuLower::current() + ->haloInfo() + .getRootAxisInfo(root_domain_[i]) + .hasHalo()) { + auto root_domain_i = root_domain_[i]->as(); + contig_ids.emplace(root_domain_i); + within_contig_ids[root_domain_i] = std::unordered_set(); is_contig_root[root_domain_[i]] = true; } else { is_contig_root[root_domain_[i]] = false; } } - auto exprs = ExprSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()}); + auto exprs = StmtSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()}); for (auto expr : exprs) { handle(expr); } } - const std::unordered_set contigIDs() const { + const std::unordered_set contigIDs() const { return contig_ids; } - const std:: - unordered_map> - withinContigIDs() const { + const std::unordered_map> + withinContigIDs() const { return within_contig_ids; } }; @@ -276,21 +266,18 @@ void updateHaloInfoForReference( // // ref_map: ref-to-consumer in consumer indexing; ref-to-producer in // producer indexing -std::unordered_map getReferenceHaloExtentMap( +std::unordered_map getReferenceHaloExtentMap( const ReferenceTensor& reference, const std::unordered_map& index_map_from_ref) { - const auto gpu_lower = GpuLower::current(); - - const auto& halo_info = gpu_lower->haloInfo(); + const auto& halo_info = GpuLower::current()->haloInfo(); - std::unordered_map reference_halo_extent_map; + std::unordered_map reference_halo_extent_map; // Propagate halo extents of the reference to the consumer or // producer tensor for (auto kv : index_map_from_ref) { - auto ref_id = gpu_lower->lowerValue(kv.first)->as(); - auto producer_or_consumer_id = - gpu_lower->lowerValue(kv.second)->as(); + auto ref_id = kv.first; + auto producer_or_consumer_id = kv.second; auto extent = halo_info.getExtent(ref_id); if (extent != nullptr) { reference_halo_extent_map[producer_or_consumer_id] = extent; @@ -302,7 +289,7 @@ std::unordered_map getReferenceHaloExtentMap( //! Offset of an index of a producer axis with respect to its //! corresponding consumer index -kir::Val* getProducerHaloOffset( +int getProducerHaloOffset( const TensorView* producer_tv, size_t producer_axis, const TensorView* consumer_tv) { @@ -325,41 +312,31 @@ kir::Val* getProducerHaloOffset( const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0); const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - kir::Val* offset = (p_pad->isConst() && c_pad->isConst()) - ? ir_builder.create( - p_pad->value().value() - c_pad->value().value()) - : ir_builder.subExpr(p_pad, c_pad); + auto offset = p_pad - c_pad; // If the consumer is a result of shifting the producer, adjust the // producer index per the offsets argument of the shift op. if (auto shift_op = dynamic_cast(consumer_tv->definition())) { - offset = ir_builder.subExpr( - offset, ir_builder.create(shift_op->offset(producer_axis))); + offset -= shift_op->offset(producer_axis); } return offset; } //! Offset producer index when necessary -kir::Val* getProducerIndexWithHalo( +Val* getProducerIndexWithHalo( const TensorView* producer_tv, size_t producer_axis, - kir::Val* producer_index, + Val* producer_index, const TensorView* consumer_tv) { const auto offset = getProducerHaloOffset(producer_tv, producer_axis, consumer_tv); - if (offset->isZeroInt()) { + if (offset == 0) { return producer_index; } - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - producer_index = ir_builder.addExpr(producer_index, offset); + producer_index = SimplifyingIrBuilder::addExpr(producer_index, offset); return producer_index; } @@ -368,58 +345,58 @@ kir::Val* getProducerIndexWithHalo( //! //! \param consumer_root_axis Position of corresponding consumer axis //! \param consumer_tv Consumer TensorView +//! \param index_map Mappings from consumer or reference to indices +//! \param use_reference_map True when index_map maps reference domains //! \param concrete_to_ref_map Mappings from concrete to reference domains -//! \param ref_index_map Mappings from reference domains to indices -kir::Val* getProducerOffsetWithGather( +Val* getProducerOffsetWithGather( size_t consumer_root_axis, const TensorView* consumer_tv, - const std::unordered_map& concrete_to_ref_map, - const std::unordered_map& ref_index_map) { + const std::unordered_map& index_map, + bool use_reference_map = false, + const std::unordered_map& concrete_to_ref_map = + {}) { const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); const auto gather_expr = dynamic_cast(consumer_tv->definition()); if (gather_expr == nullptr) { - return ir_builder.zeroVal(); + return gpu_lower->kernel()->zeroVal(); } // If the window extent is one, no specific offsetting // is necessary if (consumer_root_axis >= gather_expr->windowShape().size() || - gather_expr->windowShape()[consumer_root_axis]->isOneInt()) { - return ir_builder.zeroVal(); + gather_expr->windowShape()[consumer_root_axis] == 1) { + return gpu_lower->kernel()->zeroVal(); } // Basically, the goal is to build an expression of producer_index + // window_index, so we first need to locate the index expression // that corresponds to the window axis of this producer axis. - // Locate the root IterDomain of the reference that corresponds to the gather - // axis const auto window_axis = gather_expr->gatherAxis(consumer_root_axis); auto window_id = consumer_tv->getRootDomain().at(window_axis); - auto concrete_window_id = - gpu_lower->caIndexMap().getConcreteMappedID(window_id); - auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id); - TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end()); - IterDomain* reference_root_of_gather_axis = concrete_2_ref_it->second; - - // Now that reference_root_of_gather_axis is the IterDomain for the - // window axis, take its corresponding index from the index map - auto window_idx = - ref_index_map.at(gpu_lower->lowerValue(reference_root_of_gather_axis) - ->as()); - - // Positive (or negative) padding at offset zero means the indexing - // shifted to the negative (or positive) direction. + + // When index_map maps a reference tensor, find the corresponding + // reference ID of window_id. + if (use_reference_map) { + auto concrete_window_id = + gpu_lower->caIndexMap().getConcreteMappedID(window_id); + auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id); + TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end()); + window_id = concrete_2_ref_it->second; + } + + auto window_idx = index_map.at(window_id); + + // Positive padding at offset zero means the indexing shifted to the + // negative direction. auto pad_width = gather_expr->padWidth()[consumer_root_axis][0]; // producer offset: window_index - padding - auto producer_offset = - ir_builder.subExpr(window_idx, ir_builder.create(pad_width)); + auto producer_offset = SimplifyingIrBuilder::subExpr( + window_idx, IrBuilder::create(pad_width)); return producer_offset; - ; } //! Offset a producer index of a gather expression @@ -428,13 +405,13 @@ kir::Val* getProducerOffsetWithGather( //! expression that accesses a window position that the current loop //! structure refers to. Use getGatherProducerOffset to create an //! offset Val. -kir::Val* getProducerIndexWithGather( - kir::Val* producer_index, +Val* getProducerIndexWithGather( + Val* producer_index, size_t producer_root_axis, const TensorView* producer_tv, const TensorView* consumer_tv, const std::unordered_map& concrete_to_ref_map, - const std::unordered_map& ref_index_map) { + const std::unordered_map& ref_index_map) { auto gather_op = dynamic_cast(consumer_tv->definition()); // Just return the producer index as is if this is not a gather @@ -460,22 +437,18 @@ kir::Val* getProducerIndexWithGather( ", producer_axis: ", producer_root_axis); - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); auto offset = getProducerOffsetWithGather( - consumer_axis, consumer_tv, concrete_to_ref_map, ref_index_map); - return ir_builder.addExpr(producer_index, offset); + consumer_axis, consumer_tv, ref_index_map, true, concrete_to_ref_map); + return SimplifyingIrBuilder::addExpr(producer_index, offset); } // Adjusts a global consumer index when its root domain is partially // split. Note that non-global consumer indices don't need any // adjustment. -kir::Val* getGlobalConsumerOffsetWithPartialSplit(kir::IterDomain* root_id) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto offset = gpu_lower->partialSplitMap().getStartOffset(root_id); +Val* getGlobalConsumerOffsetWithPartialSplit(IterDomain* root_id) { + auto offset = GpuLower::current()->partialSplitMap().getStartOffset(root_id); if (offset == nullptr) { - return ir_builder.zeroVal(); + return GpuLower::current()->kernel()->zeroVal(); } else { return offset; } @@ -488,13 +461,12 @@ kir::Val* getGlobalConsumerOffsetWithPartialSplit(kir::IterDomain* root_id) { // it needs to be added to the index. Also, when the producer itself // also has a non-zero split offset, that needs to be subtracted from // the index. -kir::Val* getProducerIndexWithPartialSplit( - kir::Val* producer_index, +Val* getProducerIndexWithPartialSplit( + Val* producer_index, IterDomain* producer_root_id, const TensorView* producer_tv, const TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); auto p2c = PairwiseRootDomainMap(producer_tv, consumer_tv) @@ -509,31 +481,29 @@ kir::Val* getProducerIndexWithPartialSplit( auto consumer_offset = gpu_lower->partialSplitMap().getStartOffset(consumer_root_id); - auto consumer_offset_kir = consumer_offset == nullptr - ? ir_builder.zeroVal() - : gpu_lower->lowerValue(consumer_offset); + consumer_offset = consumer_offset == nullptr ? gpu_lower->kernel()->zeroVal() + : consumer_offset; auto producer_offset = gpu_lower->partialSplitMap().getStartOffset(producer_root_id); - auto producer_offset_kir = producer_offset == nullptr - ? ir_builder.zeroVal() - : gpu_lower->lowerValue(producer_offset); + producer_offset = producer_offset == nullptr ? gpu_lower->kernel()->zeroVal() + : producer_offset; // If the producer is on global memory, it's always allocated // without trimming the out-of-bounds region, so the consumer offset // should be added to the index. if (producer_tv->getMemoryType() == MemoryType::Global) { - if (consumer_offset_kir->isZeroInt()) { + if (consumer_offset->isZeroInt()) { return producer_index; } else { - return ir_builder.addExpr(producer_index, consumer_offset_kir); + return IrBuilder::addExpr(producer_index, consumer_offset); } } // Non-global case. Difference of the split offsets must be // accounted. - auto diff = ir_builder.subExpr(consumer_offset_kir, producer_offset_kir); + auto diff = IrBuilder::subExpr(consumer_offset, producer_offset); kir::ExpressionEvaluator ee; auto diff_eval = ee.evaluate(diff); // We currently only allow constant offsetting @@ -543,19 +513,16 @@ kir::Val* getProducerIndexWithPartialSplit( return producer_index; } - return ir_builder.addExpr( - producer_index, ir_builder.create(diff_eval.value())); + return IrBuilder::addExpr( + producer_index, IrBuilder::create(diff_eval.value())); } } // namespace void IndexCompute::handle(Split* split) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto in_id = gpu_lower->lowerValue(split->in())->as(); - auto outer_id = gpu_lower->lowerValue(split->outer())->as(); - auto inner_id = gpu_lower->lowerValue(split->inner())->as(); + auto in_id = split->in()->as(); + auto outer_id = split->outer()->as(); + auto inner_id = split->inner()->as(); auto outer_it = index_map_.find(outer_id); auto inner_it = index_map_.find(inner_id); @@ -588,8 +555,8 @@ void IndexCompute::handle(Split* split) { } if (isZero(in_id)) { - index_map_[in_id] = ir_builder.create(0); - extent_map_[in_id] = ir_builder.create(0); + index_map_[in_id] = GpuLower::current()->kernel()->zeroVal(); + extent_map_[in_id] = GpuLower::current()->kernel()->zeroVal(); } else if (zero_merged_in && outer_zero) { index_map_[in_id] = inner_ind; extent_map_[in_id] = getExtent(inner_id); @@ -597,24 +564,21 @@ void IndexCompute::handle(Split* split) { index_map_[in_id] = outer_ind; extent_map_[in_id] = getExtent(outer_id); } else { - index_map_[in_id] = ir_builder.addExpr( - ir_builder.mulExpr(outer_ind, getExtent(inner_id)), inner_ind); + index_map_[in_id] = IrBuilder::addExpr( + IrBuilder::mulExpr(outer_ind, getExtent(inner_id)), inner_ind); // The extent should be updated only when its allocation is // partial, i.e., zero_merged_in is true. See PR #1270. if (zero_merged_in) { extent_map_[in_id] = - ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id)); + IrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id)); } } } void IndexCompute::handle(Merge* merge) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto out_id = gpu_lower->lowerValue(merge->out())->as(); - auto outer_id = gpu_lower->lowerValue(merge->outer())->as(); - auto inner_id = gpu_lower->lowerValue(merge->inner())->as(); + auto out_id = merge->out(); + auto outer_id = merge->outer(); + auto inner_id = merge->inner(); auto out_it = index_map_.find(out_id); if (out_it == index_map_.end()) { @@ -622,7 +586,7 @@ void IndexCompute::handle(Merge* merge) { } auto out_ind = out_it->second; - auto zero = ir_builder.zeroVal(); + auto zero = GpuLower::current()->kernel()->zeroVal(); if (isZero(out_id)) { index_map_[outer_id] = zero; @@ -643,17 +607,14 @@ void IndexCompute::handle(Merge* merge) { TORCH_INTERNAL_ASSERT(!input_ids.empty()); for (auto root_id : input_ids) { - index_map_[gpu_lower->lowerValue(root_id)->as()] = zero; + index_map_[root_id] = zero; } - index_map_[gpu_lower - ->lowerValue(*(input_ids.end() - 1)) - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ->as()] = out_ind; + index_map_[*(input_ids.end() - 1)] = out_ind; return; } - kir::Val* inner_extent = getExtent(inner_id); + Val* inner_extent = getExtent(inner_id); // When the reference has halo extent for inner_id, that extent needs to // be used to un-merge @@ -718,8 +679,8 @@ void IndexCompute::handle(Merge* merge) { zero_merged_in_.emplace(inner_id); zero_merged_in_.emplace(outer_id); } else { - index_map_[outer_id] = ir_builder.divExpr(out_ind, inner_extent); - index_map_[inner_id] = ir_builder.modExpr(out_ind, inner_extent); + index_map_[outer_id] = IrBuilder::divExpr(out_ind, inner_extent); + index_map_[inner_id] = IrBuilder::modExpr(out_ind, inner_extent); } } @@ -739,13 +700,13 @@ void IndexCompute::handle(Expr* e) { // using TransformIter::runBackward; IndexCompute::IndexCompute( const TensorDomain* _td, - std::unordered_map initial_index_map, - std::unordered_map extent_map, - std::unordered_set zero_domains, - std::unordered_set zero_merged_in, + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in, const std::vector& root_contiguity, - std::unordered_set preferred_paths, - std::unordered_map reference_halo_extent_map) + std::unordered_set preferred_paths, + std::unordered_map reference_halo_extent_map) : td_(_td), index_map_(std::move(initial_index_map)), extent_map_(std::move(extent_map)), @@ -783,7 +744,7 @@ void IndexCompute::run() { traverseFrom(td_->fusion(), domain_vals, false); } -kir::Val* IndexCompute::getExtent(kir::IterDomain* id) { +Val* IndexCompute::getExtent(IterDomain* id) { // Pick from extent_map_ if available. Previously parallel // dimensions were ued (e.g., blockDim.x), however, it would result // in out-of-bounds errors when the extent of IterDomain is smaller @@ -795,11 +756,11 @@ kir::Val* IndexCompute::getExtent(kir::IterDomain* id) { } } -bool IndexCompute::hasZeroMerged(kir::IterDomain* id) const { +bool IndexCompute::hasZeroMerged(IterDomain* id) const { return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id); } -bool IndexCompute::isZero(kir::IterDomain* id) const { +bool IndexCompute::isZero(IterDomain* id) const { return zero_domains_.find(id) != zero_domains_.end(); } @@ -807,22 +768,17 @@ IndexCompute IndexCompute::updateIndexCompute( const TensorDomain* new_td, const std::unordered_map& id_map, const std::vector& root_contiguity, - const std::unordered_map& - reference_halo_extent_map) { + const std::unordered_map& reference_halo_extent_map) { FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute"); - const auto gpu_lower = GpuLower::current(); - - std::unordered_map updated_index_map; - std::unordered_map updated_extent_map; - std::unordered_set updated_zero_domains; - std::unordered_set updated_zero_merged_in; + std::unordered_map updated_index_map; + std::unordered_map updated_extent_map; + std::unordered_set updated_zero_domains; + std::unordered_set updated_zero_merged_in; for (auto id_entry : id_map) { - kir::IterDomain* prev_id = - gpu_lower->lowerValue(id_entry.first)->as(); - kir::IterDomain* new_id = - gpu_lower->lowerValue(id_entry.second)->as(); + IterDomain* prev_id = id_entry.first; + IterDomain* new_id = id_entry.second; if (index_map_.find(prev_id) != index_map_.end()) { updated_index_map[new_id] = index_map_.at(prev_id); @@ -859,8 +815,8 @@ class UpdateLeafIndices : public IterVisitor { public: UpdateLeafIndices( const TensorDomain* td, - std::unordered_map initial_index_map, - std::unordered_map extent_map) + std::unordered_map initial_index_map, + std::unordered_map extent_map) : td_(td), index_map_(std::move(initial_index_map)), extent_map_(std::move(extent_map)) { @@ -870,11 +826,11 @@ class UpdateLeafIndices : public IterVisitor { traverseFrom(td_->fusion(), domain_vals, false); } - const std::unordered_map& indexMap() const { + const std::unordered_map& indexMap() const { return index_map_; } - const std::unordered_map& extentMap() const { + const std::unordered_map& extentMap() const { return extent_map_; } @@ -882,13 +838,9 @@ class UpdateLeafIndices : public IterVisitor { using IterVisitor::handle; void handle(Split* split) override { - const auto gpu_lower = GpuLower::current(); - - auto in_id = gpu_lower->lowerValue(split->in())->as(); - auto outer_id = - gpu_lower->lowerValue(split->outer())->as(); - auto inner_id = - gpu_lower->lowerValue(split->inner())->as(); + auto in_id = split->in(); + auto outer_id = split->outer(); + auto inner_id = split->inner(); // Nothing need to be done when mappings for the output axes // already exist. @@ -899,22 +851,17 @@ class UpdateLeafIndices : public IterVisitor { return; } - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto factor = gpu_lower->lowerValue(split->factor()); - index_map_[inner_id] = ir_builder.modExpr(index_map_[in_id], factor); + auto factor = split->factor(); + index_map_[inner_id] = IrBuilder::modExpr(index_map_[in_id], factor); extent_map_[inner_id] = factor; - index_map_[outer_id] = ir_builder.divExpr(index_map_[in_id], factor); - extent_map_[outer_id] = ir_builder.ceilDivExpr(getExtent(in_id), factor); + index_map_[outer_id] = IrBuilder::divExpr(index_map_[in_id], factor); + extent_map_[outer_id] = IrBuilder::ceilDivExpr(getExtent(in_id), factor); } void handle(Merge* merge) override { - const auto gpu_lower = GpuLower::current(); - - auto out_id = gpu_lower->lowerValue(merge->out())->as(); - auto outer_id = - gpu_lower->lowerValue(merge->outer())->as(); - auto inner_id = - gpu_lower->lowerValue(merge->inner())->as(); + auto out_id = merge->out(); + auto outer_id = merge->outer(); + auto inner_id = merge->inner(); // Nothing need to be done when mappings for the output axes // already exist. @@ -927,17 +874,16 @@ class UpdateLeafIndices : public IterVisitor { TORCH_INTERNAL_ASSERT( index_map_.find(inner_id) != index_map_.end(), "Inner ID not found"); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - index_map_[out_id] = ir_builder.mulExpr( + index_map_[out_id] = IrBuilder::mulExpr( index_map_[inner_id], - ir_builder.mulExpr(index_map_[outer_id], getExtent(inner_id))); + IrBuilder::mulExpr(index_map_[outer_id], getExtent(inner_id))); extent_map_[out_id] = - ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id)); + IrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id)); } // return extent_map_[id] if exists, else return id->extent() - kir::Val* getExtent(kir::IterDomain* id) { + Val* getExtent(IterDomain* id) { if (extent_map_.find(id) != extent_map_.end()) { return extent_map_.at(id); } else { @@ -947,25 +893,21 @@ class UpdateLeafIndices : public IterVisitor { private: const TensorDomain* td_; - std::unordered_map index_map_; - std::unordered_map extent_map_; + std::unordered_map index_map_; + std::unordered_map extent_map_; }; // Returns halo-extended extent if id has halo. Otherwise, just // returns id->extent. -kir::Val* getHaloExtentOfRootAxis( - IterDomain* id, - kir::Val* normal_extent = nullptr) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - +Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) { if (normal_extent == nullptr) { - normal_extent = gpu_lower->lowerValue(id->extent()); + normal_extent = id->extent(); } - const auto& halo = gpu_lower->haloInfo().getRootAxisInfo(id); + const auto& halo = GpuLower::current()->haloInfo().getRootAxisInfo(id); if (halo.hasHalo()) { - auto halo_extent = ir_builder.addExpr(normal_extent, halo.width()); + auto halo_extent = + IrBuilder::addExpr(normal_extent, IrBuilder::create(halo.width())); return halo_extent; } else { return normal_extent; @@ -976,10 +918,10 @@ kir::Val* getHaloExtentOfRootAxis( IndexSwizzle::IndexSwizzle( const TensorView* tv, - std::unordered_map initial_index_map, - std::unordered_map extent_map, - std::unordered_set zero_domains, - std::unordered_set zero_merged_in) + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in) : IndexCompute( tv->domain(), std::move(initial_index_map), @@ -996,8 +938,6 @@ void IndexSwizzle::run() { swizzle_type_ == SwizzleType::NoSwizzle || swizzle_type_ == SwizzleType::Transpose, "Invalid swizzle type"); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); if (swizzle_type_ == SwizzleType::Transpose) { // Shifts the second axis by the first axis as ((idx_1 + idx_2) % // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also @@ -1013,20 +953,16 @@ void IndexSwizzle::run() { IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0); IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1); - kir::IterDomain* id_to_swizzle_i_kir = - gpu_lower->lowerValue(id_to_swizzle_i)->as(); - kir::IterDomain* id_to_swizzle_j_kir = - gpu_lower->lowerValue(id_to_swizzle_j)->as(); - - if (indexMap().find(id_to_swizzle_i_kir) != indexMap().end() && - indexMap().find(id_to_swizzle_j_kir) != indexMap().end()) { - auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i_kir); - auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j_kir); - - auto swizzled_idx = ir_builder.modExpr( - ir_builder.addExpr(idx_to_swizzle_i, idx_to_swizzle_j), - id_to_swizzle_j_kir->extent()); - index_map_[id_to_swizzle_j_kir] = swizzled_idx; + + if (indexMap().find(id_to_swizzle_i) != indexMap().end() && + indexMap().find(id_to_swizzle_j) != indexMap().end()) { + auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i); + auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j); + + auto swizzled_idx = IrBuilder::modExpr( + IrBuilder::addExpr(idx_to_swizzle_i, idx_to_swizzle_j), + id_to_swizzle_j->extent()); + index_map_[id_to_swizzle_j] = swizzled_idx; swizzled_ids_.insert(id_to_swizzle_j); IndexCompute::run(); } @@ -1055,17 +991,15 @@ namespace { // to loop indices as well as a set of loops that do not contribute to // indexing. std::pair< - std::unordered_map, + std::unordered_map, std::unordered_set> indexMapFromTV( const TensorView* tv, const std::vector& loops, - const std::pair& alloc_point, - bool as_consumer) { + kir::ForLoop* alloc_loop, + bool as_consumer, + kir::ForLoop* double_buffer_loop = nullptr) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto alloc_loop = alloc_point.first; bool within_alloc = false; if (alloc_loop == nullptr) { @@ -1076,7 +1010,7 @@ indexMapFromTV( const bool is_shared = tv->getMemoryType() == MemoryType::Shared; const bool is_local = tv->getMemoryType() == MemoryType::Local; - std::unordered_map loop_to_ind_map; + std::unordered_map loop_to_ind_map; // When indexed as a producer, the parallel types of the the // producer domains may not be the same as those of the loops, but @@ -1085,17 +1019,16 @@ indexMapFromTV( // with zero isn't valid. That's only valid when there's a matching // IterDomain in the producer tensor that has the same parallel // type. - auto find_matching_parallel_domain = [tv](kir::IterDomain* id) -> bool { + auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool { const auto gpu_lower = GpuLower::current(); auto it = std::find_if( tv->domain()->domain().begin(), tv->domain()->domain().end(), [&](IterDomain* tv_id) { - auto kir_tv_id = gpu_lower->lowerValue(tv_id)->as(); // Matching is done using the index and loop maps. See // validateParallelize as well. - return gpu_lower->caIndexMap().areMapped(id, kir_tv_id) || - (gpu_lower->caLoopMap().areMapped(id, kir_tv_id) && + return gpu_lower->caIndexMap().areMapped(id, tv_id) || + (gpu_lower->caLoopMap().areMapped(id, tv_id) && ir_utils::derivedFromRootCAAxes(tv, tv_id)); }); if (it == tv->domain()->domain().end()) { @@ -1103,7 +1036,7 @@ indexMapFromTV( } auto corresponding_domain = *it; - return corresponding_domain->getParallelType() == id->parallelType(); + return corresponding_domain->getParallelType() == id->getParallelType(); }; // Track domains that do not contibute to the resulting @@ -1113,7 +1046,7 @@ indexMapFromTV( std::unordered_set zero_loops; for (auto loop : loops) { - kir::Val* idx = nullptr; + Val* idx = nullptr; const auto same_parallel_type = as_consumer || find_matching_parallel_domain(loop->iter_domain()); // See also LoopNestGenerator::pushAlloc. @@ -1123,7 +1056,7 @@ indexMapFromTV( (loop->iter_domain()->isThread() && is_global)) { idx = loop->index(); } else { - idx = ir_builder.zeroVal(); + idx = GpuLower::current()->kernel()->zeroVal(); zero_loops.insert(loop); } } else if ( @@ -1145,7 +1078,7 @@ indexMapFromTV( // parallel type (loop->iter_domain()->isThread() && is_local && same_parallel_type) || loop->vectorize()) { - idx = ir_builder.zeroVal(); + idx = GpuLower::current()->kernel()->zeroVal(); if (!loop->vectorize()) { zero_loops.insert(loop); } @@ -1153,6 +1086,10 @@ indexMapFromTV( idx = loop->index(); } + if (loop == double_buffer_loop) { + idx = IrBuilder::addExpr(idx, GpuLower::current()->kernel()->oneVal()); + } + loop_to_ind_map[loop] = idx; if (!within_alloc && loop == alloc_loop) { @@ -1184,8 +1121,6 @@ void ensureStaticIndexing( within_alloc = true; } - const auto gpu_lower = GpuLower::current(); - for (auto loop : loops) { if (!within_alloc) { if (loop == alloc_loop) { @@ -1193,7 +1128,7 @@ void ensureStaticIndexing( } continue; } - kir::IterDomain* loop_id = loop->iter_domain(); + IterDomain* loop_id = loop->iter_domain(); if (loop->vectorize() || loop_id->isThread()) { continue; } @@ -1203,7 +1138,7 @@ void ensureStaticIndexing( auto it = std::find_if( tv->domain()->domain().begin(), tv->domain()->domain().end(), - [loop_id, gpu_lower, &id_map](IterDomain* id) { + [loop_id, &id_map](IterDomain* id) { if (id->isBroadcast() || id->isReduction() || id->isStride()) { return false; } @@ -1211,8 +1146,7 @@ void ensureStaticIndexing( if (id_replacement != id_map.end()) { id = id_replacement->second; } - auto kir_id = gpu_lower->lowerValue(id)->as(); - return gpu_lower->caLoopMap().areMapped(loop_id, kir_id); + return GpuLower::current()->caLoopMap().areMapped(loop_id, id); }); if (it != tv->domain()->domain().end()) { loop->requireUnroll(); @@ -1260,13 +1194,12 @@ std::unordered_map indexMapReferenceTo( } // namespace -std::vector Index::getGlobalProducerStridedIndices( +std::vector Index::getGlobalProducerStridedIndices( TensorView* producer_tv, const TensorView* consumer_tv, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex"); const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); @@ -1311,9 +1244,12 @@ std::vector Index::getGlobalProducerStridedIndices( } } + kir::ForLoop* db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + // Index into the reference tensor. Reference indexing will handle vectorized // dims where index should be set to 0 - auto ref_compute = getReferenceIndexing(loops, reference_domain); + auto ref_compute = getReferenceIndexing(loops, reference_domain, db_loop); // Forward vectorized IDs to index into producer correctly // We want p_id to be vectorized like consumer just for the indexing, then we @@ -1355,25 +1291,24 @@ std::vector Index::getGlobalProducerStridedIndices( auto root_dom = producer_tv->getMaybeRFactorDomain(); // TODO: Abstract stride logic to reuse with consumer indexing - auto zero = ir_builder.create(0); - std::vector strides(root_dom.size(), nullptr); + std::vector strides(root_dom.size(), nullptr); { int stride_i = 0; for (const auto i : c10::irange(root_dom.size())) { if (root_dom[i]->isReduction() || root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) { - strides[i] = zero; + strides[i] = GpuLower::current()->kernel()->oneVal(); continue; } std::stringstream ss; ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]"; - strides[i] = ir_builder.create(ss.str(), DataType::Int); + strides[i] = IrBuilder::create(ss.str(), DataType::Int); } } TORCH_INTERNAL_ASSERT( root_dom.size() == producer_tv->domain()->contiguity().size()); - kir::Val* cur_contig_stride = ir_builder.create(1); + Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; if (root_dom[dim]->isReduction()) { @@ -1383,14 +1318,12 @@ std::vector Index::getGlobalProducerStridedIndices( continue; } - kir::Val* root_ind = nullptr; - auto kir_root_dom = - gpu_lower->lowerValue(root_dom[dim])->as(); - if (producer_indexing.indexMap().find(kir_root_dom) != + Val* root_ind = nullptr; + if (producer_indexing.indexMap().find(root_dom[dim]) != producer_indexing.indexMap().end()) { - root_ind = producer_indexing.indexMap().at(kir_root_dom); + root_ind = producer_indexing.indexMap().at(root_dom[dim]); } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) { - root_ind = zero; + root_ind = GpuLower::current()->kernel()->zeroVal(); } TORCH_INTERNAL_ASSERT( @@ -1410,12 +1343,12 @@ std::vector Index::getGlobalProducerStridedIndices( // by extent of this dimension auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); cur_contig_stride = - ir_builder.mulExpr(cur_contig_stride, root_dim_extent); + IrBuilder::mulExpr(cur_contig_stride, root_dim_extent); } else { // If non contiguous dimension, keep local stride information, set cur // stride to local stride * local raw extent auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); - cur_contig_stride = ir_builder.mulExpr(strides[dim], root_dim_extent); + cur_contig_stride = IrBuilder::mulExpr(strides[dim], root_dim_extent); } } @@ -1423,7 +1356,8 @@ std::vector Index::getGlobalProducerStridedIndices( loops.empty() ? nullptr : loops.back()->vectorize_shift(); // Global striding - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { // If the domain is derived from a trivial reduction, no indexing // to create. @@ -1434,20 +1368,17 @@ std::vector Index::getGlobalProducerStridedIndices( continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - producer_indexing.indexMap().find(kir_root_dom_i) != + producer_indexing.indexMap().find(root_dom[i]) != producer_indexing.indexMap().end(), "Couldn't find root mapping for TV", producer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - auto root_ind = producer_indexing.indexMap().at(kir_root_dom_i); + auto root_ind = producer_indexing.indexMap().at(root_dom[i]); root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv); @@ -1465,9 +1396,9 @@ std::vector Index::getGlobalProducerStridedIndices( if (root_ind->isZeroInt()) { continue; } else { - auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]); + auto strided_ind = IrBuilder::mulExpr(root_ind, strides[i]); if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { - strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift); + strided_inds[i] = IrBuilder::addExpr(strided_ind, vectorize_shift); } else { strided_inds[i] = strided_ind; } @@ -1478,12 +1409,11 @@ std::vector Index::getGlobalProducerStridedIndices( } // Producer index for either shared or local memory -std::vector Index::getNonGlobalProducerStridedIndices( +std::vector Index::getNonGlobalProducerStridedIndices( TensorView* producer_tv, const TensorView* consumer_tv, const std::vector& loops) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); @@ -1526,31 +1456,35 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } + kir::ForLoop* consumer_db_loop = + gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + // Find allocation point of producer relative to loop nests. P2C map is // required because producer was replayed as consumer, so we can't use the // regular compute at maps to line up its iter domains with the for loops. - auto alloc_point = - loop_utils::getAllocPoint(producer_tv, loops, p2c_alloc_map, true); - std::unordered_map loop_to_ind_map; + auto alloc_info = + loop_utils::getAllocInformation(producer_tv, loops, p2c_alloc_map, true); + std::unordered_map loop_to_ind_map; std::unordered_set zero_loops; - std::tie(loop_to_ind_map, zero_loops) = - indexMapFromTV(producer_tv, loops, alloc_point, false); + std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( + producer_tv, loops, alloc_info.init_for_loop, false, consumer_db_loop); - ensureStaticIndexing(producer_tv, alloc_point.first, loops, p2c_alloc_map); + ensureStaticIndexing( + producer_tv, alloc_info.init_for_loop, loops, p2c_alloc_map); // Map loop nests to indicies, zeroing out those not used due to locality of // memory - std::unordered_map ref_id_to_ind_map; + std::unordered_map ref_id_to_ind_map; // Track which domains are not used - std::unordered_set ref_zero_domains; + std::unordered_set ref_zero_domains; // Due to rfactor/initialization reference_domain may be bigger than loop nest // structure, ignore IterDomains that aren't present in the loop nest when // indexing reference. TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); for (const auto loop_i : c10::irange(loops.size())) { - auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i)) - ->as(); + auto ref_axis = reference_domain->axis(loop_i); ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]]; if (zero_loops.count(loops[loop_i]) > 0) { ref_zero_domains.insert(ref_axis); @@ -1677,8 +1611,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( } // Already an entry for this root domain, continue - if (index_map.find(gpu_lower->lowerValue(root_id)->as()) != - index_map.end()) { + if (index_map.find(root_id) != index_map.end()) { continue; } @@ -1690,25 +1623,23 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { if (skip_indexing.count(root_dom[i])) { continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_i) != index_map.end(), + index_map.find(root_dom[i]) != index_map.end(), "Couldn't find root mapping for TV", producer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - auto root_ind_i = index_map.at(kir_root_dom_i); + auto root_ind_i = index_map.at(root_dom[i]); root_ind_i = getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv); @@ -1729,17 +1660,14 @@ std::vector Index::getNonGlobalProducerStridedIndices( } // Compute striding for this index. - kir::Val* stride = nullptr; + Val* stride = nullptr; for (const auto j : c10::irange(i + 1, root_dom.size())) { if (skip_indexing.count(root_dom[j])) { continue; } - auto kir_root_dom_j = - gpu_lower->lowerValue(root_dom[j])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_j) != index_map.end(), + index_map.find(root_dom[j]) != index_map.end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", @@ -1747,37 +1675,49 @@ std::vector Index::getNonGlobalProducerStridedIndices( " id: ", root_dom[i]); - auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end() - ? kir_root_dom_j->extent() - : extent_map.at(kir_root_dom_j); + auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() + ? root_dom[j]->extent() + : extent_map.at(root_dom[j]); root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j); - if (zero_domain_map.count(kir_root_dom_j) == 0) { + if (zero_domain_map.count(root_dom[j]) == 0) { if (stride == nullptr) { stride = root_ext_j; } else { - stride = ir_builder.mulExpr(stride, root_ext_j); + stride = IrBuilder::mulExpr(stride, root_ext_j); } } } if (stride != nullptr) { - strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride); + strided_inds[i] = IrBuilder::mulExpr(root_ind_i, stride); } else { strided_inds[i] = root_ind_i; } } + if (producer_tv->isDoubleBuffered()) { + auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + producer_tv, loops, true); + if (db_loop != nullptr) { + auto db_switch_index = + IrBuilder::modExpr(db_loop->index(), IrBuilder::create(2)); + auto original_alloc_size = + gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); + auto db_strided_index = + IrBuilder::mulExpr(db_switch_index, original_alloc_size); + strided_inds.push_back(db_strided_index); + } + } return strided_inds; } -std::vector Index::getGlobalConsumerStridedIndices( +std::vector Index::getGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex"); const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); @@ -1813,26 +1753,27 @@ std::vector Index::getGlobalConsumerStridedIndices( auto root_dom = consumer_tv->getMaybeRFactorDomain(); // TODO: Abstract stride logic to reuse with producer indexing - auto zero = ir_builder.zeroVal(); - std::vector strides(root_dom.size(), zero); + std::vector strides( + root_dom.size(), GpuLower::current()->kernel()->oneVal()); { int stride_i = 0; for (const auto i : c10::irange(root_dom.size())) { if (root_dom[i]->isReduction() || root_dom[i]->getIterType() == IterType::BroadcastWithoutStride || root_dom[i]->isStride()) { - strides[i] = zero; + strides[i] = GpuLower::current()->kernel()->oneVal(); continue; } std::stringstream ss; ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]"; - strides[i] = ir_builder.create(ss.str(), DataType::Int); + strides[i] = + SimplifyingIrBuilder::create(ss.str(), DataType::Int); } } TORCH_INTERNAL_ASSERT( root_dom.size() == consumer_tv->domain()->contiguity().size()); - kir::Val* cur_contig_stride = ir_builder.oneVal(); + Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) { @@ -1842,14 +1783,12 @@ std::vector Index::getGlobalConsumerStridedIndices( continue; } - kir::Val* root_ind = nullptr; - auto kir_root_dom = - gpu_lower->lowerValue(root_dom[dim])->as(); - if (consumer_indexing.indexMap().find(kir_root_dom) != + Val* root_ind = nullptr; + if (consumer_indexing.indexMap().find(root_dom[dim]) != consumer_indexing.indexMap().end()) { - root_ind = consumer_indexing.indexMap().at(kir_root_dom); + root_ind = consumer_indexing.indexMap().at(root_dom[dim]); } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) { - root_ind = zero; + root_ind = GpuLower::current()->kernel()->zeroVal(); } TORCH_INTERNAL_ASSERT( @@ -1869,11 +1808,11 @@ std::vector Index::getGlobalConsumerStridedIndices( // by extent of this dimension auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]); cur_contig_stride = - ir_builder.mulExpr(cur_contig_stride, root_dim_extent); + SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent); } else { // If non contiguous dimension, keep local stride information, set cur // stride to local stride * local raw extent - cur_contig_stride = ir_builder.mulExpr( + cur_contig_stride = SimplifyingIrBuilder::mulExpr( strides[dim], getHaloExtentOfRootAxis(root_dom[dim])); } } @@ -1882,7 +1821,8 @@ std::vector Index::getGlobalConsumerStridedIndices( loops.empty() ? nullptr : loops.back()->vectorize_shift(); // Global striding - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { // See a comment in indexing to root domains in getGlobalProducerIndex. if (root_dom[i]->isReduction() || @@ -1893,71 +1833,70 @@ std::vector Index::getGlobalConsumerStridedIndices( continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - consumer_indexing.indexMap().find(kir_root_dom_i) != + consumer_indexing.indexMap().find(root_dom[i]) != consumer_indexing.indexMap().end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - auto root_ind = consumer_indexing.indexMap().at(kir_root_dom_i); + auto root_ind = consumer_indexing.indexMap().at(root_dom[i]); - root_ind = ir_builder.addExpr( - root_ind, getGlobalConsumerOffsetWithPartialSplit(kir_root_dom_i)); + root_ind = SimplifyingIrBuilder::addExpr( + root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i])); if (root_ind->isZeroInt()) { continue; } else { - auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]); + auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]); if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { - strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift); + strided_inds[i] = + SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift); } else { strided_inds[i] = strided_ind; } } } + TORCH_INTERNAL_ASSERT( + strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); + return strided_inds; } // Consumer index for either shared or local memory -std::vector Index::getNonGlobalConsumerStridedIndices( +std::vector Index::getNonGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // Get a reference tensor replayed as existing loop structure auto reference = IndexReferenceReplay::getReference(loops); auto reference_domain = reference.domain; auto reference_id_map = reference.concrete_to_id; - auto alloc_point = loop_utils::getAllocPoint(consumer_tv, loops); - std::unordered_map loop_to_ind_map; + auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops); + std::unordered_map loop_to_ind_map; std::unordered_set zero_loops; std::tie(loop_to_ind_map, zero_loops) = - indexMapFromTV(consumer_tv, loops, alloc_point, true); + indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true); - ensureStaticIndexing(consumer_tv, alloc_point.first, loops); + ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops); // Map loop nests to indicies, zeroing out those not used due to locality of // memory - std::unordered_map ref_id_to_ind_map; - std::unordered_set ref_zero_domains; + std::unordered_map ref_id_to_ind_map; + std::unordered_set ref_zero_domains; // Due to rfactor/initialization reference_domain may be bigger than loop nest // structure, ignore IterDomains that aren't present in the loop nest when // indexing reference. TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); for (const auto loop_i : c10::irange(loops.size())) { - auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i)) - ->as(); + auto ref_axis = reference_domain->axis(loop_i); ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]]; if (zero_loops.count(loops[loop_i]) > 0) { ref_zero_domains.insert(ref_axis); @@ -2022,7 +1961,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( // Indices should now be mapped onto IterDomains in consumer, so just grab // and use them. auto root_dom = consumer_tv->getMaybeRFactorDomain(); - std::vector strided_inds(root_dom.size(), ir_builder.zeroVal()); + std::vector strided_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() || gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) || @@ -2030,25 +1970,22 @@ std::vector Index::getNonGlobalConsumerStridedIndices( continue; } - auto kir_root_dom_i = - gpu_lower->lowerValue(root_dom[i])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_i) != index_map.end(), + index_map.find(root_dom[i]) != index_map.end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", i, " id: ", - kir::toString(kir_root_dom_i)); + root_dom[i]->toString()); - const auto root_ind_i = index_map.at(kir_root_dom_i); + const auto root_ind_i = index_map.at(root_dom[i]); if (root_ind_i->isZeroInt()) { continue; } // Compute striding for this index. - kir::Val* stride = nullptr; + Val* stride = nullptr; for (const auto j : c10::irange(i + 1, root_dom.size())) { if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() || gpu_lower->trivialReductionInfo().isDerived(root_dom[j]) || @@ -2056,11 +1993,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( continue; } - auto kir_root_dom_j = - gpu_lower->lowerValue(root_dom[j])->as(); - TORCH_INTERNAL_ASSERT( - index_map.find(kir_root_dom_j) != index_map.end(), + index_map.find(root_dom[j]) != index_map.end(), "Couldn't find root mapping for TV", consumer_tv->name(), " dim: ", @@ -2068,45 +2002,67 @@ std::vector Index::getNonGlobalConsumerStridedIndices( " id: ", root_dom[i]); - auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end() - ? kir_root_dom_j->extent() - : extent_map.at(kir_root_dom_j); + auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() + ? root_dom[j]->extent() + : extent_map.at(root_dom[j]); root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j); - if (zero_domain_map.count(kir_root_dom_j) == 0) { + if (zero_domain_map.count(root_dom[j]) == 0) { if (stride == nullptr) { stride = root_ext_j; } else { - stride = ir_builder.mulExpr(stride, root_ext_j); + stride = IrBuilder::mulExpr(stride, root_ext_j); } } } if (stride != nullptr) { - strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride); + strided_inds[i] = IrBuilder::mulExpr(root_ind_i, stride); } else { strided_inds[i] = root_ind_i; } } + // This check was originally done in getConsumerStridedIndices, but + // the number of strided index values depends on the loop where the + // consumer tensor is located. If it's double buffered and not in + // the prologue loop, strided_inds ends up having one more + // index, so it's just much simpler to check here before adding the + // additional index for double buffering. + TORCH_INTERNAL_ASSERT( + strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); + + if (consumer_tv->isDoubleBuffered()) { + auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + if (db_loop != nullptr) { + auto db_switch_index = IrBuilder::subExpr( + gpu_lower->kernel()->oneVal(), + IrBuilder::modExpr(db_loop->index(), IrBuilder::create(2))); + auto original_alloc_size = + gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv); + auto db_strided_index = + IrBuilder::mulExpr(db_switch_index, original_alloc_size); + strided_inds.push_back(db_strided_index); + } + } + return strided_inds; } -std::vector Index::getProducerStridedIndices( +std::vector Index::getProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices"); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - if (producer->domain()->noReductions().size() == 0) { - return std::vector( - producer->getMaybeRFactorDomain().size(), ir_builder.zeroVal()); + return std::vector( + producer->getMaybeRFactorDomain().size(), + GpuLower::current()->kernel()->zeroVal()); } - std::vector strided_indices; + std::vector strided_indices; if (producer->getMemoryType() == MemoryType::Global) { strided_indices = getGlobalProducerStridedIndices(producer, consumer, loops); @@ -2116,7 +2072,9 @@ std::vector Index::getProducerStridedIndices( } TORCH_INTERNAL_ASSERT( - strided_indices.size() == producer->getMaybeRFactorDomain().size()); + strided_indices.size() == + producer->getMaybeRFactorDomain().size() + + (producer->isDoubleBuffered() ? 1 : 0)); return strided_indices; } @@ -2126,35 +2084,27 @@ kir::TensorIndex* Index::getProducerIndex( TensorView* producer, const TensorView* consumer, const std::vector& loops) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto strided_indices = getProducerStridedIndices(producer, consumer, loops); - return ir_builder.create(producer, strided_indices); + return IrBuilder::create(producer, strided_indices); } -std::vector Index::getConsumerStridedIndices( +std::vector Index::getConsumerStridedIndices( const TensorView* consumer, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices"); - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - if (consumer->domain()->noReductions().size() == 0) { - return std::vector( - consumer->getMaybeRFactorDomain().size(), ir_builder.zeroVal()); + return std::vector( + consumer->getMaybeRFactorDomain().size(), + GpuLower::current()->kernel()->zeroVal()); } - std::vector strided_indices; + std::vector strided_indices; if (consumer->getMemoryType() == MemoryType::Global) { strided_indices = getGlobalConsumerStridedIndices(consumer, loops); } else { strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops); } - TORCH_INTERNAL_ASSERT( - strided_indices.size() == consumer->getMaybeRFactorDomain().size()); - return strided_indices; } @@ -2162,11 +2112,8 @@ std::vector Index::getConsumerStridedIndices( kir::TensorIndex* Index::getConsumerIndex( const TensorView* consumer, const std::vector& loops) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto strided_indices = getConsumerStridedIndices(consumer, loops); - return ir_builder.create(consumer, strided_indices); + return IrBuilder::create(consumer, strided_indices); } namespace { @@ -2184,37 +2131,19 @@ struct PredicateDomainInfo { bool is_non_divisible_split = false; }; -// Find iteration domains in the history of reference comprised only of -// merge operations. Only return iteration domains that are subsequently fed -// into a split, or are in the provided domain. In other words, we don't want to -// return every IterDomain that's contiguous, just the one closest to the -// leaves. Predicates are not associated with physical memory so we can treat -// all of them as contiguous merges. +// Find iteration domains in the history of a consumer to predicate comprised +// only of merge operations. Only return iteration domains that are subsequently +// fed into a split, or are in the provided domain. In other words, we don't +// want to return every IterDomain that's contiguous, just the one closest to +// the leaves. Predicates are not associated with physical memory so we can +// treat all of them as contiguous merges. std::vector getPredicateContigIds( - const ReferenceTensor& reference, - TensorView* consumer_tv, - const std::unordered_map& ref_2_consumer) { + TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - std::vector reference_predicated_root_domain; - for (const auto consumer_root : consumer_tv->getRootDomain()) { - if (consumer_root->isBroadcast()) { - continue; - } - auto consumer_root_concrete = - gpu_lower->caIndexMap().getConcreteMappedID(consumer_root); - auto it = reference.concrete_to_id.find(consumer_root_concrete); - // When initializing a reduction buffer, the reduction axis - // doesn't have a loop, so the reference tensor doesn't have a - // mapped domain. The reduction axis can be safely ignored. - if (it == reference.concrete_to_id.end()) { - continue; - } - auto reference_root = it->second; - reference_predicated_root_domain.emplace_back(reference_root); - } + const auto& consumer_root_domain = consumer_tv->getRootDomain(); - std::vector contiguous_ids = reference_predicated_root_domain; + std::vector contiguous_ids = consumer_root_domain; if (contiguous_ids.empty()) { return std::vector(); @@ -2227,20 +2156,24 @@ std::vector getPredicateContigIds( // about halo to do correct predication, so they must be excluded. std::unordered_set excluded_ids; - for (auto reference_predicated_id : reference_predicated_root_domain) { - if (GpuLower::current() - ->haloInfo() - .getRootAxisInfo(reference_predicated_id) - .hasHalo()) { + for (auto consumer_root_id : consumer_root_domain) { + if (gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id).hasHalo()) { + excluded_ids.insert(consumer_root_id); continue; } - auto it = ref_2_consumer.find(reference_predicated_id); - if (it == ref_2_consumer.end()) { + if (consumer_root_id->maybePartial()) { + excluded_ids.insert(consumer_root_id); continue; } - auto consumer_root_id = it->second; - if (consumer_root_id->maybePartial()) { - excluded_ids.insert(reference_predicated_id); + // When consumer_root_id is a broadcast domain, do not allow contig + // predication as the merged output is not mapped with the + // reference unless the concrete domain is also a broadcast + // domain. + if (consumer_root_id->isBroadcast() && + !gpu_lower->caLoopMap() + .getConcreteMappedID(consumer_root_id) + ->isBroadcast()) { + excluded_ids.insert(consumer_root_id); continue; } // Shifted or gathered axes need to be predicated at the root domain @@ -2252,15 +2185,16 @@ std::vector getPredicateContigIds( auto consumer_root_pos = consumer_tv->domain()->rootPosOf(consumer_root_id); if ((shift_expr && shift_expr->offset(consumer_root_pos) != 0) || (gather_expr && consumer_root_pos < gather_expr->windowShape().size() && - !gather_expr->windowShape().at(consumer_root_pos)->isOneInt())) { - excluded_ids.insert(reference_predicated_id); + gather_expr->windowShape().at(consumer_root_pos) != 1)) { + excluded_ids.insert(consumer_root_id); } } // Run through iteration domain history - auto exprs = ExprSort::getExprs( + auto exprs = StmtSort::getExprs( consumer_tv->fusion(), - {reference.domain->domain().begin(), reference.domain->domain().end()}); + {consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end()}); for (auto expr : exprs) { // If not a merge, output is not contiguous @@ -2296,8 +2230,7 @@ std::vector getPredicateContigIds( // reference_predicated_root_domain. auto contig_root_vals = IterVisitor::getInputsTo( {contig_id}, - {reference_predicated_root_domain.begin(), - reference_predicated_root_domain.end()}); + {consumer_root_domain.begin(), consumer_root_domain.end()}); auto contig_root_ids = ir_utils::filterByType(contig_root_vals); PredicateDomainInfo contig_id_info; contig_id_info.id = contig_id; @@ -2312,8 +2245,7 @@ IterDomain* getMappedReferenceDomain( IterDomain* id, const ReferenceTensor& reference) { // Partially overlaps with getPredicateContigIds() - const auto gpu_lower = GpuLower::current(); - auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(id); + auto concrete_id = GpuLower::current()->caIndexMap().getConcreteMappedID(id); auto it = reference.concrete_to_id.find(concrete_id); if (it == reference.concrete_to_id.end()) { return nullptr; @@ -2321,9 +2253,8 @@ IterDomain* getMappedReferenceDomain( return it->second; } -std::vector getNonDivisibleReferenceDomainsToPredicate( - TensorView* consumer_tv, - const ReferenceTensor& reference) { +std::vector getNonDivisibleConsumerDomainsToPredicate( + TensorView* consumer_tv) { const auto& non_divisible_split_info = GpuLower::current()->nonDivisibleSplitInfo(); @@ -2337,11 +2268,7 @@ std::vector getNonDivisibleReferenceDomainsToPredicate( const auto& splits_to_predicate = it->second; for (auto split : splits_to_predicate) { - auto ref_id = getMappedReferenceDomain(split->in(), reference); - if (ref_id == nullptr) { - continue; - } - PredicateDomainInfo info{ref_id, {ref_id}, true}; + PredicateDomainInfo info{split->in(), {split->in()}, true}; pred_info_vec.emplace_back(info); } @@ -2352,9 +2279,8 @@ bool needsPadding(TensorView* tv) { auto shift_expr = dynamic_cast(tv->definition()); auto gather_expr = dynamic_cast(tv->definition()); - // Padding is only necessary for padded shift and - // gather - return (shift_expr != nullptr && shift_expr->pad()) || gather_expr != nullptr; + return (shift_expr != nullptr && shift_expr->hasPadding()) || + (gather_expr != nullptr && gather_expr->hasPadding()); } // Get an additional offset of a stop index when building a predicate @@ -2364,11 +2290,10 @@ bool needsPadding(TensorView* tv) { // compared with each other by just looking at the additional offsets. // // consumer_root_id: the domain for which a stop predicate is being built. -kir::Val* getUnswitchStopOffset( +int getUnswitchStopOffset( IterDomain* consumer_root_id, TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id); @@ -2376,7 +2301,7 @@ kir::Val* getUnswitchStopOffset( // If the consumer root domain to predicate does not have halo, no // adjustment is required. if (!halo_info.hasHalo()) { - return ir_builder.zeroVal(); + return 0; } // Find if this contig_id is used in the unswitched domains @@ -2400,22 +2325,14 @@ kir::Val* getUnswitchStopOffset( })) { return halo_info.width(); } else { - return ir_builder.zeroVal(); + return 0; } } -// Get offsets for the start and stop predicates. Similar to the -// gather case, but it's a little simpler as it does not (yet) -// dynamic shifting. -void adjustStartAndStopOffsetsForShift( - std::vector& start_offsets, - std::vector& stop_offsets, +std::pair getStartAndStopOffsetsForShift( TensorView* consumer_tv, IterDomain* consumer_id, bool padding_predicate) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - TORCH_INTERNAL_ASSERT(consumer_id != nullptr); auto shift_expr = dynamic_cast(consumer_tv->definition()); @@ -2423,105 +2340,124 @@ void adjustStartAndStopOffsetsForShift( // Adjustment is not necessary if not shift. // Even so, padding predicate does not need any adjustment. if (shift_expr == nullptr || padding_predicate) { - return; + return { + GpuLower::current()->kernel()->zeroVal(), + GpuLower::current()->kernel()->zeroVal()}; } const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id); - // Assume this adjustment is done first, so start and stop offsets - // just contain zeroVal. - TORCH_INTERNAL_ASSERT( - start_offsets.size() == 1 && start_offsets[0]->isZeroInt() && - stop_offsets.size() == 1 && stop_offsets[0]->isZeroInt()); - start_offsets.clear(); - stop_offsets.clear(); - - // The consumer offset is zero. - auto consumer_offset = 0; - // The producer offset is based off the consumer offset. - auto producer_offset = 0; - - // When the shift operation is not padded, the start and stop positions of the - // consumer axis, i.e., consumer_id->start and - // consumer_id->stop_ofset, are adjusted accordingly, which includes - // the effect of the shift offset, so using the consumer offset is - // sufficient as the only predicate is sufficient. - - if (shift_expr->pad()) { - // Positive shift offset means shifting the input tensor to the - // positive direction, so the producer offset becomes negative. - auto shift_offset = shift_expr->offset(root_axis_pos); - producer_offset = -shift_offset; - } - - // Since shift doesn't allow dynamic offsets, we can statically - // choose more restrictive offsets between the producer and consumer - // offsets. The start predicate uses greater-than, so using the - // smaller offset is sufficient. Similarly, for the stop predicate, - // using the larger offset is sufficient. - auto start_offset = std::min(consumer_offset, producer_offset); - auto stop_offset = std::max(consumer_offset, producer_offset); - - start_offsets.push_back(ir_builder.create(start_offset)); - stop_offsets.push_back(ir_builder.create(stop_offset)); + // The first or last N elements, where N is the padding width, + // correspond to the padding predicate. + + const auto shift_offset = shift_expr->offset(root_axis_pos); + const auto pad_width = shift_expr->padWidth().at(root_axis_pos); + + int start_offset = 0; + int stop_offset = 0; + + if (shift_offset > 0) { + start_offset = -pad_width; + } else if (shift_offset < 0) { + stop_offset = pad_width; + } + + return { + IrBuilder::create(start_offset), + IrBuilder::create(stop_offset)}; } -// Get offsets for the start and stop predicates. There can be two -// offsets because the shift offset is determined by a loop index. -void adjustStartAndStopOffsetsForGather( - std::vector& start_offsets, - std::vector& stop_offsets, +std::pair getStartAndStopOffsetsForGather( TensorView* consumer_tv, IterDomain* consumer_id, - const ReferenceTensor& reference, - const std::unordered_map& ref_start_index_map, - const std::unordered_map& ref_stop_index_map, + const std::unordered_map& ref_start_index_map, + const std::unordered_map& ref_stop_index_map, bool padding_predicate) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - TORCH_INTERNAL_ASSERT(consumer_id != nullptr); // Adjustment is not necessary if not gather. Even so, padding // predicate does not need any adjustment. if (!consumer_tv->definition()->isA() || padding_predicate) { - return; + return { + GpuLower::current()->kernel()->zeroVal(), + GpuLower::current()->kernel()->zeroVal()}; } const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id); - // Assume this adjustment is done first, so start and stop offsets - // just contain zeroVal. - TORCH_INTERNAL_ASSERT( - start_offsets.size() == 1 && start_offsets[0]->isZeroInt() && - stop_offsets.size() == 1 && stop_offsets[0]->isZeroInt()); - start_offsets.clear(); - stop_offsets.clear(); - auto producer_start_offset = getProducerOffsetWithGather( - root_axis_pos, - consumer_tv, - reference.concrete_to_id, - ref_start_index_map); + root_axis_pos, consumer_tv, ref_start_index_map); auto producer_stop_offset = getProducerOffsetWithGather( - root_axis_pos, consumer_tv, reference.concrete_to_id, ref_stop_index_map); + root_axis_pos, consumer_tv, ref_stop_index_map); - // The producer and consumer accesses must be predicated as it is - // not statically determined which is more restrictive. + auto consumer_start_offset = GpuLower::current()->kernel()->zeroVal(); + auto consumer_stop_offset = GpuLower::current()->kernel()->zeroVal(); - // Consumer offsets are just zero. - start_offsets.push_back(ir_builder.zeroVal()); - stop_offsets.push_back(ir_builder.zeroVal()); + if (producer_start_offset->isZeroInt() && producer_stop_offset->isZeroInt()) { + return {consumer_start_offset, consumer_stop_offset}; + } + + Val* start_offset = nullptr; + Val* stop_offset = nullptr; + + // In the normal case, take the minimum of the start and the + // maximum of the stop offsets. If there's no padding, the producer + // offset must be always larger than the consumer + // offset. So, the consumer and produce offsets can be always used + // for the start and stop offsets, respectively. + const auto pad_left = + consumer_tv->definition()->as()->padWidth()[root_axis_pos][0]; + const auto pad_right = + consumer_tv->definition()->as()->padWidth()[root_axis_pos][1]; + const auto window_size = + consumer_tv->definition()->as()->windowShape()[root_axis_pos]; - // Adds producer offsets if they are not zero. - if (!producer_start_offset->isZeroInt()) { - start_offsets.push_back(producer_start_offset); + // consumer index: index + // producer index: index + window_index - pad_left + // + // consumer extent: ext + // producer extent: ext + window_size - 1 - pad_left - pad_right + // + // consumer stop pred: index < ext + // producer stop pred: index + window_index - pad_left < ext + window_size - 1 + // - pad_left - pad_right + // -> index + window_index - pad_left - (window_size - 1 - + // pad_left - pad_right) < ext + // -> index + window_index - (window_size - 1 - pad_right) < + // ext + // + // consumer start pred: index >= 0 + // producer start pred: index + window_index - pad_left >= 0 + + const auto producer_ext_adj = window_size - 1 - pad_left - pad_right; + producer_stop_offset = SimplifyingIrBuilder::subExpr( + producer_stop_offset, + SimplifyingIrBuilder::create(producer_ext_adj)); + + // As commented above, when pad_left is zero, the consumer predicate + // is always more restrictive than the producer predicate. + if (pad_left == 0) { + start_offset = consumer_start_offset; + } else { + start_offset = SimplifyingIrBuilder::minExpr( + consumer_start_offset, producer_start_offset); } - if (!producer_stop_offset->isZeroInt()) { - stop_offsets.push_back(producer_stop_offset); + // As commented above, when pad_right is zero, the consumer + // predicate is always more restrictive than the producer + // predicate. + if (pad_right == 0) { + stop_offset = consumer_stop_offset; + } else { + stop_offset = SimplifyingIrBuilder::maxExpr( + consumer_stop_offset, producer_stop_offset); } + + TORCH_INTERNAL_ASSERT(start_offset != nullptr); + TORCH_INTERNAL_ASSERT(stop_offset != nullptr); + + return {start_offset, stop_offset}; } // Get the start and stop limit offsets that define the valid range to @@ -2530,18 +2466,16 @@ void adjustStartAndStopOffsetsForGather( // stop that's different from extent. Also, when IterDomain has halo, // the actual offsets of the logical start and stop positions are // shifted. -std::pair getStartAndStopLimitOffsets( +std::pair getStartAndStopLimitOffsets( IterDomain* consumer_id, bool padding_predicate, bool non_divisible_pred) { const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); TORCH_INTERNAL_ASSERT(consumer_id != nullptr); - kir::Val* start_limit = gpu_lower->lowerValue(consumer_id->start()); - kir::Val* stop_limit = - ir_builder.negExpr(gpu_lower->lowerValue(consumer_id->stopOffset())); + Val* start_limit = consumer_id->start(); + Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset()); if (!non_divisible_pred) { AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_id); @@ -2554,12 +2488,14 @@ std::pair getStartAndStopLimitOffsets( // [0, left halo)[start_limit, stop_limit)[0, right halo) // if (!padding_predicate) { - start_limit = ir_builder.addExpr(start_limit, halo_info.width(0)); - stop_limit = ir_builder.addExpr(stop_limit, halo_info.width(0)); + start_limit = + SimplifyingIrBuilder::addExpr(start_limit, halo_info.width(0)); + stop_limit = + SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width(0)); } else { // In case of the padding predicate, the whole range, including both left // and right halo regions, is computed. - stop_limit = ir_builder.addExpr(stop_limit, halo_info.width()); + stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width()); } } else { // For non-divisible predicates, the index must be predicated such @@ -2568,28 +2504,26 @@ std::pair getStartAndStopLimitOffsets( // isn't a root domain. if (gpu_lower->haloInfo().hasHaloWidth(consumer_id)) { auto halo = gpu_lower->haloInfo().getHaloWidth(consumer_id); - stop_limit = ir_builder.addExpr(stop_limit, halo); + stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo); } } return {start_limit, stop_limit}; } -// Return an index map for a predicate reference tensor. Two different +// Return an IndexCompute for a predicate reference tensor. Two different // maps are used when generating predicates for unswitched expressions // as start and stop conditions need to use different loop-to-index // mappings. -std::unordered_map getPredicateReferenceIndexing( +auto getPredicateReferenceIndexing( const std::vector& loops, const ReferenceTensor& reference, kir::ForLoop* unswitch_or_vec_loop, + IterDomain* double_buffer_axis, bool start) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - auto reference_domain = reference.domain; - std::unordered_map loop_to_ind_map; + std::unordered_map loop_to_ind_map; std::transform( loops.begin(), @@ -2606,7 +2540,7 @@ std::unordered_map getPredicateReferenceIndexing( // vectorized loop should be like this. bool vectorized_pred = - unswitch_or_vec_loop->iter_domain()->parallelType() == + unswitch_or_vec_loop->iter_domain()->getParallelType() == ParallelType::Vectorize; TORCH_INTERNAL_ASSERT( @@ -2614,12 +2548,11 @@ std::unordered_map getPredicateReferenceIndexing( "Invalid reference generated."); bool within_unswitch = false; - const auto one = ir_builder.oneVal(); for (const auto loop_i : c10::irange(loops.size())) { auto loop = loops[loop_i]; auto loop_id = loop->iter_domain(); - auto loop_pt = loop_id->parallelType(); + auto loop_pt = loop_id->getParallelType(); auto ref_id = reference_domain->axis(loop_i); if (loop == unswitch_or_vec_loop) { @@ -2668,20 +2601,21 @@ std::unordered_map getPredicateReferenceIndexing( if (loop->stop() == loop_id->extent()) { loop_to_ind_map[loop] = loop->start(); } else if (start) { - loop_to_ind_map[loop] = ir_builder.zeroVal(); + loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); } else { // Note that the parallel dimension is used rather than // loop-stop(). See the above comment. - loop_to_ind_map[loop] = ir_builder.subExpr( - gpu_lower->parallelDimensionMap().get(loop_pt), - ir_builder.create(1)); + loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( + GpuLower::current()->parallelDimensionMap().get(loop_pt), + GpuLower::current()->kernel()->zeroVal()); } } else if (start) { - loop_to_ind_map[loop] = ir_builder.zeroVal(); + loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); } else { // Similar to the above, loop_id()->extent() is // used here instead of loop->stop(). See the above comment. - loop_to_ind_map[loop] = ir_builder.subExpr(loop_id->extent(), one); + loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( + loop_id->extent(), GpuLower::current()->kernel()->oneVal()); } } @@ -2693,9 +2627,27 @@ std::unordered_map getPredicateReferenceIndexing( } } + if (double_buffer_axis != nullptr) { + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + double_buffer_axis, loops, true); + if (db_loop != nullptr) { + auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop); + TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end()); + auto cur_index = loop_to_ind_map_it->second; + // if cur_index is not the same as the index of db_loop, it must + // be true that that index has been modified to support + // unswitch. In that case, it is not necessary to move ahead the + // index for double buffering. + if (cur_index == db_loop->index()) { + loop_to_ind_map[db_loop] = IrBuilder::addExpr( + cur_index, GpuLower::current()->kernel()->oneVal()); + } + } + } + // Add magic zero to a loop pretty far inside in indexing - kir::IterDomain* magic_zero_loop = nullptr; - std::unordered_map ref_id_to_ind_map; + IterDomain* magic_zero_loop = nullptr; + std::unordered_map ref_id_to_ind_map; // Due to rfactor/initialization reference_domain may be bigger than loop nest // structure TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); @@ -2703,19 +2655,19 @@ std::unordered_map getPredicateReferenceIndexing( auto loop = loops[loop_i]; auto ind = loop_to_ind_map[loops[loop_i]]; auto ref_axis = reference_domain->axis(loop_i); - auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as(); if (Index::protectWithMagicZero(loop, ref_axis, ind)) { - magic_zero_loop = kir_ref_axis; + magic_zero_loop = ref_axis; } - ref_id_to_ind_map[kir_ref_axis] = loop_to_ind_map[loop]; + ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loop]; } if (ref_id_to_ind_map.count(magic_zero_loop)) { auto& ind = ref_id_to_ind_map[magic_zero_loop]; if (!ind->isConstScalar()) { - ind = ir_builder.addExpr(ind, ir_builder.magicZeroVal()); + ind = SimplifyingIrBuilder::addExpr( + ind, GpuLower::current()->kernel()->magicZeroVal()); } } @@ -2729,7 +2681,7 @@ std::unordered_map getPredicateReferenceIndexing( ref_self_map.insert({id, id}); }); - std::unordered_map reference_halo_extent_map = + std::unordered_map reference_halo_extent_map = getReferenceHaloExtentMap(reference, ref_self_map); // Index into the reference tensor @@ -2741,64 +2693,55 @@ std::unordered_map getPredicateReferenceIndexing( {}, reference_halo_extent_map); - return index_compute.indexMap(); + return index_compute; } // Get the offsets for the start and stop predicates. The offsets // are to be added to the index. -std::pair, std::vector> getStartAndStopOffsets( +std::pair getStartAndStopOffsets( IterDomain* consumer_id, TensorView* consumer_tv, const ReferenceTensor& reference, - const std::unordered_map& ref_start_index_map, - const std::unordered_map& ref_stop_index_map, + const std::unordered_map& consumer_start_index_map, + const std::unordered_map& consumer_stop_index_map, bool padding_predicate, bool unswitch, bool non_divisible_pred) { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - // By default, the offsets for the start and stop predicates are - // just zero. - std::vector start_offsets{ir_builder.zeroVal()}; - std::vector stop_offsets{ir_builder.zeroVal()}; - - if (consumer_id == nullptr) { - return {start_offsets, stop_offsets}; + // just zero. All halo-related adjustments are done at root domains, + // so consumer_id is not a root domain, no adjustment is required. + if (consumer_id->definition() != nullptr && !non_divisible_pred) { + return { + GpuLower::current()->kernel()->zeroVal(), + GpuLower::current()->kernel()->zeroVal()}; } auto consumer_def = consumer_tv->definition(); + Val* start_offset = GpuLower::current()->kernel()->zeroVal(); + Val* stop_offset = GpuLower::current()->kernel()->zeroVal(); + // These adjustments are not required when predicating non-divisible splits if (!non_divisible_pred) { if (consumer_def->isA()) { - adjustStartAndStopOffsetsForShift( - start_offsets, - stop_offsets, - consumer_tv, - consumer_id, - padding_predicate); + std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForShift( + consumer_tv, consumer_id, padding_predicate); } else if (consumer_def->isA()) { - adjustStartAndStopOffsetsForGather( - start_offsets, - stop_offsets, + std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForGather( consumer_tv, consumer_id, - reference, - ref_start_index_map, - ref_stop_index_map, + consumer_start_index_map, + consumer_stop_index_map, padding_predicate); } // Adjustment for partial split - auto partial_split_offset = getGlobalConsumerOffsetWithPartialSplit( - gpu_lower->lowerValue(consumer_id)->as()); - for (auto& start_offset : start_offsets) { - start_offset = ir_builder.addExpr(start_offset, partial_split_offset); - } - for (auto& stop_offset : stop_offsets) { - stop_offset = ir_builder.addExpr(stop_offset, partial_split_offset); - } + auto partial_split_offset = + getGlobalConsumerOffsetWithPartialSplit(consumer_id); + start_offset = + SimplifyingIrBuilder::addExpr(start_offset, partial_split_offset); + stop_offset = + SimplifyingIrBuilder::addExpr(stop_offset, partial_split_offset); // If generating a predicate for unswitch, adjust the stop offset to // accommodate the addition of halo to the loop stop. See the @@ -2808,9 +2751,8 @@ std::pair, std::vector> getStartAndStopOffsets !padding_predicate, "Unswitch should not use the padding predicate"); auto stop_unswitch_offset = getUnswitchStopOffset(consumer_id, consumer_tv); - for (auto& stop_offset : stop_offsets) { - stop_offset = ir_builder.addExpr(stop_offset, stop_unswitch_offset); - } + stop_offset = + SimplifyingIrBuilder::addExpr(stop_offset, stop_unswitch_offset); } } @@ -2830,39 +2772,49 @@ std::pair, std::vector> getStartAndStopOffsets // index + (start_offset - start_limit) >= 0 // index + (stop_offset - stop_limit) < extent - for (auto& start_offset : start_offsets) { - start_offset = ir_builder.subExpr(start_offset, limits.first); - } - for (auto& stop_offset : stop_offsets) { - stop_offset = ir_builder.subExpr(stop_offset, limits.second); - } + start_offset = SimplifyingIrBuilder::subExpr(start_offset, limits.first); + stop_offset = SimplifyingIrBuilder::subExpr(stop_offset, limits.second); - return {start_offsets, stop_offsets}; + return {start_offset, stop_offset}; } -bool canOmitStartPredicate(kir::Val* start_offset) { +// A partial value of a start offset is returned if determined to be +// safe. Nullptr is returned if it can be omitted completely. +Val* simplifyStartOffset(Val* start_offset) { // Start predicate can be omitted when start_offset >= 0. - auto offset_val = start_offset->as()->value(); - return offset_val.has_value() && offset_val.value() >= 0; + auto offset_val = start_offset->as()->value(); + if (offset_val.has_value() && offset_val.value() >= 0) { + return nullptr; + } + + // start_offset may look like min(0, window_index - pad). Then, can + // remove min and leave the rhs only. + auto def = dynamic_cast(start_offset->definition()); + if (def != nullptr && def->getBinaryOpType() == BinaryOpType::Min && + def->lhs()->isZeroInt()) { + return def->rhs(); + } + + return start_offset; } bool canOmitStopPredicate( - kir::Val* stop_index, - kir::Val* stop_offset, - kir::IterDomain* kir_contig_id) { + Val* stop_index, + Val* stop_offset, + IterDomain* contig_id) { bool index_simple = stop_index->definition() == nullptr; // The definition may be just adding the magic zero, which can be // effectively considered "simple" if (!index_simple && isProtectedWithMagicZero(stop_index)) { // Make sure the lhs of stop_index is simple. - auto lhs = stop_index->definition()->as()->lhs(); + auto lhs = stop_index->definition()->as()->lhs(); if (lhs->definition() == nullptr) { index_simple = true; } } // Omit only when both the index and extent are "simple". - if (!(index_simple && kir_contig_id->extent()->definition() == nullptr)) { + if (!(index_simple && contig_id->extent()->definition() == nullptr)) { return false; } @@ -2873,33 +2825,32 @@ bool canOmitStopPredicate( // omitted if extent + halo + stop_offset < extent, i.e., halo + // stop_offset <= 0. - auto stop_offset_val = stop_offset->as()->value(); + auto stop_offset_val = stop_offset->as()->value(); - auto halo_ext = - gpu_lower->haloInfo().getRootAxisInfo(kir_contig_id).width()->value(); + auto halo_ext = gpu_lower->haloInfo().getRootAxisInfo(contig_id).width(); // If they are not compile-time constant, can't prove the // condition. - if (!stop_offset_val.has_value() || !halo_ext.has_value()) { + if (!stop_offset_val.has_value()) { return false; } - if (halo_ext.value() + stop_offset_val.value() > 0) { + if (halo_ext + stop_offset_val.value() > 0) { return false; } // When the domain is parallelized, the parallel dimension must be // exact. Otherwise, there would be extra threads/blocks that need // to be predicated out. - if (isParallelTypeThread(kir_contig_id->parallelType())) { + if (isParallelTypeThread(contig_id->getParallelType())) { if (!gpu_lower->parallelDimensionMap().isExact( - kir_contig_id->parallelType())) { + contig_id->getParallelType())) { return false; } // If the domain has halo, the loop is expanded by the halo // extent, so we can't prove the loop extent is the same as the // parallel dimension. - if (!(halo_ext.has_value() && halo_ext.value() == 0)) { + if (halo_ext != 0) { return false; } } @@ -2912,50 +2863,70 @@ bool canOmitStopPredicate( // Returns predicates and the concrete (by loop map) root domains they cover std::pair, ReferenceTensor> Index:: getReferenceRootPredicates( - const kir::TensorView* kir_consumer_tv, + TensorView* consumer_tv, const std::vector& loops, kir::ForLoop* unswitch_or_vec_loop, bool shift_padding) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates"); const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); + + const bool is_unswitch = unswitch_or_vec_loop != nullptr; // Nothing needs to be done when padding is not required. - if (shift_padding && !needsPadding(kir_consumer_tv->fuserTv())) { + if (shift_padding && !needsPadding(consumer_tv)) { return {{RootPredicateInfo::getFalseInfo()}, ReferenceTensor{}}; } - auto consumer_tv = kir_consumer_tv->fuserTv(); - // Get a reference tensor replayed as existing loop structure ReferenceTensor reference = IndexReferenceReplay::getReference(loops); // Generate halo information for reference. updateHaloInfoForReference(reference, consumer_tv); + const auto ref_2_consumer = indexMapReferenceTo( + consumer_tv, gpu_lower->caIndexMap(), reference.concrete_to_id); + + const auto reference_halo_extent_map = + getReferenceHaloExtentMap(reference, ref_2_consumer); + + auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv); + // Both start and stop positions may need to be predicated. Indexing // differs when generating predicates for unswitch. // NOTE: If we could find-and-replace KIR nodes, we could just // generate one index map, clone it and replace the loop-to-index // mappings of unswitched loops for the start predicate. - const auto ref_stop_index_map = getPredicateReferenceIndexing( - loops, reference, unswitch_or_vec_loop, false); - // If not unswitch, share the same indexing map as the stop index map - const auto& ref_start_index_map = unswitch_or_vec_loop != nullptr - ? getPredicateReferenceIndexing( - loops, reference, unswitch_or_vec_loop, true) - : ref_stop_index_map; - - auto ref_2_consumer = indexMapReferenceTo( - consumer_tv, gpu_lower->caIndexMap(), reference.concrete_to_id); + auto ref_stop_indexing = getPredicateReferenceIndexing( + loops, reference, unswitch_or_vec_loop, db_axis, false); + const auto consumer_stop_indexing = ref_stop_indexing.updateIndexCompute( + consumer_tv->domain(), + ref_2_consumer, + std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), + reference_halo_extent_map); + const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap(); + + // If not unswitch, share the same indexing map as the stop index + // map + std::unordered_map consumer_start_index_map; + if (is_unswitch) { + auto ref_start_indexing = getPredicateReferenceIndexing( + loops, reference, unswitch_or_vec_loop, db_axis, true); + const auto consumer_start_indexing = ref_start_indexing.updateIndexCompute( + consumer_tv->domain(), + ref_2_consumer, + std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), + reference_halo_extent_map); + consumer_start_index_map = consumer_start_indexing.indexMap(); + } else { + consumer_start_index_map = consumer_stop_index_map; + } // Get the contiguous ids we need to generate predicates for - auto contig_id_infos = - getPredicateContigIds(reference, consumer_tv, ref_2_consumer); + auto contig_id_infos = getPredicateContigIds(consumer_tv); auto non_divisible_splits = - getNonDivisibleReferenceDomainsToPredicate(consumer_tv, reference); + getNonDivisibleConsumerDomainsToPredicate(consumer_tv); contig_id_infos.insert( contig_id_infos.end(), non_divisible_splits.begin(), @@ -2972,52 +2943,22 @@ std::pair, ReferenceTensor> Index:: } auto root_ids = contig_id_entry.covered_ids; - auto kir_contig_id = - gpu_lower->lowerValue(contig_id)->as(); - const auto ref_stop_indexing_it = ref_stop_index_map.find(kir_contig_id); + const auto consumer_stop_indexing_it = + consumer_stop_index_map.find(contig_id); - // First condition below is due to broadcasts in consumers of consumer that - // are not in consumer there can be unresolved indexing in the reference - // tensor. This can happen when we have something like: TV3[i1o*i2, i1i] and - // TV1[i2] where tv3 and tv1 share their outer dimension. i1 will be part of - // reference tensors root domain, but when indexing into TV1 there aren't - // enough indices to resolve it. - // - // The condition also happens with Misaligned predicates, where + // First condition below happens with Misaligned predicates, where // inner-most vectorized loops are not included in the loops // parameter. Predicates involving vectorized loops are separately // generated in lower_misaligned_vectorization. // - // It can also happens with rfactored reductions. The reference - // tensor may include rfactored domains, so the contig id may be - // a root domain of the reference, not a rfactor root. Since - // there is no loop for rfactor domains, there's no indexing - // mapping for root domains. This seems safe as it can only happen - // with rfactor and rfactored tensors do not need predicates. - // // Second condition is simply to avoid predication on broadcasting axes as // it's not required. - if (ref_stop_indexing_it == ref_stop_index_map.end() || - ref_stop_indexing_it->second->isZeroInt()) { + if (consumer_stop_indexing_it == consumer_stop_index_map.end() || + consumer_stop_indexing_it->second->isZeroInt()) { continue; } - // Find a corresponding consumer root id if exists. Used to - // support shift. If a contig_id is a merged non-root domain, nothing - // is required to do for shift as shift-related domains are - // excluded from contig domains. - IterDomain* consumer_id = nullptr; - if (contig_id->definition() == nullptr || - contig_id_entry.is_non_divisible_split) { - auto it = ref_2_consumer.find(contig_id); - if (it != ref_2_consumer.end()) { - consumer_id = it->second; - } else { - continue; - } - } - RootPredicateInfo info; // Compute offsets for start and stop predicate. For non-shift, @@ -3032,53 +2973,50 @@ std::pair, ReferenceTensor> Index:: // The final predicates will look like: // (index + start_offset) >= 0 && (index + stop_offset) < extent. - std::tie(info.start_offsets_, info.stop_offsets_) = getStartAndStopOffsets( - consumer_id, + std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets( + contig_id, consumer_tv, reference, - ref_start_index_map, - ref_stop_index_map, + consumer_start_index_map, + consumer_stop_index_map, shift_padding, unswitch_or_vec_loop != nullptr, contig_id_entry.is_non_divisible_split); - auto stop_index = ref_stop_indexing_it->second; - auto start_index = ref_start_index_map.at(kir_contig_id); + auto stop_index = consumer_stop_indexing_it->second; + auto start_index = consumer_start_index_map.at(contig_id); // Build predicates for start positions as: // start_index + start_offset >= 0 - for (auto start_offset : info.start_offsets_) { - if (canOmitStartPredicate(start_offset)) { - info.start_predicates_.push_back(ir_builder.trueVal()); - continue; - } + auto start_offset = simplifyStartOffset(info.start_offset_); + if (start_offset == nullptr) { + info.start_predicate_ = GpuLower::current()->kernel()->trueVal(); + } else { auto offsetted_start_index = - ir_builder.addExpr(start_index, start_offset); - auto pred = - ir_builder.geExpr(offsetted_start_index, ir_builder.zeroVal()) - ->as(); - info.start_predicates_.push_back(pred); + SimplifyingIrBuilder::addExpr(start_index, start_offset); + auto start_pred = + SimplifyingIrBuilder::geExpr( + offsetted_start_index, GpuLower::current()->kernel()->zeroVal()) + ->as(); + info.start_predicate_ = start_pred; } // Build predicates for stop positions as: // stop_index + stop_offset < IterDomain::extent - for (auto stop_offset : info.stop_offsets_) { - if (canOmitStopPredicate(stop_index, stop_offset, kir_contig_id)) { - info.stop_predicates_.push_back(ir_builder.trueVal()); - continue; - } - auto offsetted_stop_index = ir_builder.addExpr(stop_index, stop_offset); - auto pred = - ir_builder.ltExpr(offsetted_stop_index, kir_contig_id->extent()) - ->as(); - info.stop_predicates_.push_back(pred); + auto stop_offset = info.stop_offset_; + if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) { + info.stop_predicate_ = GpuLower::current()->kernel()->trueVal(); + } else { + auto offsetted_stop_index = + SimplifyingIrBuilder::addExpr(stop_index, stop_offset); + auto stop_pred = SimplifyingIrBuilder::ltExpr( + offsetted_stop_index, contig_id->extent()) + ->as(); + info.stop_predicate_ = stop_pred; } - // Transform ids from reference to concrete and consumer domains - // (based on loop compute at map) - for (auto ref_id : contig_id_entry.covered_ids) { - info.root_ids_.insert(reference.id_to_concrete.at(ref_id)); - info.consumer_ids_.insert(ref_2_consumer.at(ref_id)); + for (auto consumer_id : contig_id_entry.covered_ids) { + info.root_ids_.insert(consumer_id); } pred_info_vec.emplace_back(info); } @@ -3089,7 +3027,7 @@ std::pair, ReferenceTensor> Index:: bool Index::protectWithMagicZero( kir::ForLoop* loop, IterDomain* reference_domain, - kir::Val* ind) { + Val* ind) { bool ref_dom_simple = (reference_domain == nullptr ? true : reference_domain->definition() != nullptr); @@ -3100,16 +3038,9 @@ bool Index::protectWithMagicZero( } RootPredicateInfo RootPredicateInfo::getFalseInfo() { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - RootPredicateInfo info; - info.start_predicates_.push_back(ir_builder.falseVal()); - info.stop_predicates_.push_back(ir_builder.falseVal()); - // These are just placeholder. When the predicate is false, the - // offset should not be used. - info.start_offsets_.push_back(nullptr); - info.stop_offsets_.push_back(nullptr); + info.start_predicate_ = GpuLower::current()->kernel()->falseVal(); + info.stop_predicate_ = GpuLower::current()->kernel()->falseVal(); return info; } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index 83536067c19..27f1c911bde 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -69,30 +69,30 @@ class IndexCompute : public BackwardVisitor { void handle(Expr*) override; // return extent_map_[id] if exists, else return id->extent() - kir::Val* getExtent(kir::IterDomain* id); + Val* getExtent(IterDomain* id); //! True if a domain is not used to index - bool isZero(kir::IterDomain* id) const; + bool isZero(IterDomain* id) const; //! True if any dependent of a domain is not used to index - bool hasZeroMerged(kir::IterDomain* id) const; + bool hasZeroMerged(IterDomain* id) const; // Tensor domain we're mapping back to root const TensorDomain* td_; // NOLINT // Map we update as we propagate backward, containing all IDs in the // propagation. Initial indices are mapped with this map at tv->domain() - // and are back propagated to tv->rootDomain(). This index_map_ keeps the + // and are back propagated to tv->getRootDomain(). This index_map_ keeps the // indices at intermediate IterDomain's in that back propagation. - std::unordered_map index_map_; // NOLINT + std::unordered_map index_map_; // NOLINT // Map from IterDomain to their broadcasted extent. If a TV has I0*I1 but its // producer has B0*I1 this map will contain a mapping from the ID{B0*I1} to // the extent I0*I1. Also contains updated extents if we merge in a 0 index. // See zero_merged_in_. - std::unordered_map extent_map_; // NOLINT + std::unordered_map extent_map_; // NOLINT // Keeps track of domains that do not contribute to indexing - std::unordered_set zero_domains_; // NOLINT + std::unordered_set zero_domains_; // NOLINT // This set keeps track of IterDomain's that have had a zero index merged into // them. This happens if we do something like tv->axis(0)->split(4) then @@ -100,47 +100,46 @@ class IndexCompute : public BackwardVisitor { // indexing would be (0, i) then when we do the backward computation that zero // and i would attempt to be merged together. We handle indices like these // specially. - std::unordered_set zero_merged_in_; + std::unordered_set zero_merged_in_; // IDs that are a result of contiguous merges - std::unordered_set contig_ids; + std::unordered_set contig_ids; // Mentions if we should propagate an index down a particular IterDomain path // if there's an option - std::unordered_set preferred_paths_; + std::unordered_set preferred_paths_; // Map from IterDomains to halo-extended extents in corresponding // reference tensor - std::unordered_map reference_halo_extent_map_; + std::unordered_map reference_halo_extent_map_; public: - const std::unordered_map& indexMap() const { + const std::unordered_map& indexMap() const { return index_map_; } - const std::unordered_map& extentMap() const { + const std::unordered_map& extentMap() const { return extent_map_; } - const std::unordered_set& zeroDomains() const { + const std::unordered_set& zeroDomains() const { return zero_domains_; } - const std::unordered_set& zeroMergedIn() const { + const std::unordered_set& zeroMergedIn() const { return zero_merged_in_; } // Propagate back from _td using initial_index_map IndexCompute( const TensorDomain* _td, - std::unordered_map initial_index_map, - std::unordered_map _extent_map, - std::unordered_set zero_domains, - std::unordered_set _zero_merged_in, + std::unordered_map initial_index_map, + std::unordered_map _extent_map, + std::unordered_set zero_domains, + std::unordered_set _zero_merged_in, const std::vector& _root_contiguity, - std::unordered_set preferred_paths = {}, - std::unordered_map - reference_halo_extent_map = {}); + std::unordered_set preferred_paths = {}, + std::unordered_map reference_halo_extent_map = {}); // Updates index_map, extent_map, and zero_merged_in based on id_map and // returns a new IndexCompute ready to be used. @@ -148,8 +147,8 @@ class IndexCompute : public BackwardVisitor { const TensorDomain* new_td, const std::unordered_map& id_map, const std::vector& _root_contiguity, - const std::unordered_map& - reference_halo_extent_map = {}); + const std::unordered_map& reference_halo_extent_map = + {}); virtual void run(); }; @@ -159,10 +158,10 @@ class IndexSwizzle : public IndexCompute { public: IndexSwizzle( const TensorView* tv, - std::unordered_map initial_index_map, - std::unordered_map extent_map, - std::unordered_set zero_domains, - std::unordered_set zero_merged_in); + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in); void run() override; @@ -183,51 +182,45 @@ class RootPredicateInfo { friend class Index; public: - const auto& startPredicates() const { - return start_predicates_; + const auto& startPredicate() const { + return start_predicate_; } - auto& startPredicates() { - return start_predicates_; + auto& startPredicate() { + return start_predicate_; } - const auto& startOffsets() const { - return start_offsets_; + const auto& startOffset() const { + return start_offset_; } - const auto& stopPredicates() const { - return stop_predicates_; + const auto& stopPredicate() const { + return stop_predicate_; } - const auto& stopOffsets() const { - return stop_offsets_; + const auto& stopOffset() const { + return stop_offset_; } const auto& rootIds() const { return root_ids_; } - const auto& consumerIds() const { - return consumer_ids_; - } - //! Return a false RootPredicateInfo, i.e., both start and stop //! predicates are false. static RootPredicateInfo getFalseInfo(); private: - // prdicates for lower end - std::vector start_predicates_; - // prdicates for upper end - std::vector stop_predicates_; - // Offsets of the start predicate - std::vector start_offsets_; - // Offsets of the stop predicate - std::vector stop_offsets_; + // prdicate for lower end + Bool* start_predicate_ = nullptr; + // prdicate for upper end + Bool* stop_predicate_ = nullptr; + // Offset of the start predicate + Val* start_offset_ = nullptr; + // Offset of the stop predicate + Val* stop_offset_ = nullptr; // Track which roots have been handled by the generated predicates std::unordered_set root_ids_; - // Consumer IDs that correspond to root_ids_ - std::unordered_set consumer_ids_; }; // Simple interface for IndexCompute @@ -236,24 +229,24 @@ class RootPredicateInfo { class Index { private: // Producer indexing if it's in shared or local memory - static std::vector getNonGlobalProducerStridedIndices( + static std::vector getNonGlobalProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops); // Consumer indexing if it's in shared or local memory - static std::vector getNonGlobalConsumerStridedIndices( + static std::vector getNonGlobalConsumerStridedIndices( const TensorView* consumer, const std::vector& loops); // Producer if it's in global memory - static std::vector getGlobalProducerStridedIndices( + static std::vector getGlobalProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops); // Consumer indexing if it's in global memory - static std::vector getGlobalConsumerStridedIndices( + static std::vector getGlobalConsumerStridedIndices( const TensorView* consumer, const std::vector& loops); @@ -276,7 +269,7 @@ class Index { //! root domain of a producer tensor. The size of the returned //! vector is guaranteed to be equal to the number of axes of the //! indexing root domain. - static std::vector getProducerStridedIndices( + static std::vector getProducerStridedIndices( TensorView* producer, const TensorView* consumer, const std::vector& loops); @@ -285,7 +278,7 @@ class Index { //! root domain of a consumer tensor. The size of the returned //! vector is guaranteed to be equal to the number of axes of the //! indexing root domain. - static std::vector getConsumerStridedIndices( + static std::vector getConsumerStridedIndices( const TensorView* consumer, const std::vector& loops); @@ -313,7 +306,7 @@ class Index { //! vectorized loop. static std::pair, ReferenceTensor> getReferenceRootPredicates( - const kir::TensorView* kir_consumer_tv, + TensorView* consumer_tv, const std::vector& loops, kir::ForLoop* unswitch_or_vec_loop, bool padding_predicate); @@ -328,7 +321,7 @@ class Index { static bool protectWithMagicZero( kir::ForLoop* loop, IterDomain* reference_domain = nullptr, - kir::Val* ind = nullptr); + Val* ind = nullptr); }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp index fcd0a8937ed..27e5b93e94e 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp @@ -1,11 +1,10 @@ #include #include +#include #include #include #include -#include -#include namespace torch { namespace jit { @@ -41,16 +40,12 @@ IterDomain* IndexReferenceReplay::idCopy(IterDomain* id) { // reduction. All we care about are the transformations, and trying to make // sure we track correctly a replaying with consistent reduction/broadcast // domains is challenging and unnecessary. - auto copied_id = - new IterDomain(id->start(), id->extent(), id->getParallelType()); + auto copied_id = IrBuilder::create( + id->container(), id->start(), id->extent(), id->getParallelType()); replayed_ids_.emplace_back(copied_id); return copied_id; } -IterDomain* IndexReferenceReplay::toFusionID(kir::IterDomain* kir_id) { - return ca_map_.toFusion(kir_id); -} - IterDomain* IndexReferenceReplay::toConcrete(IterDomain* id) { return ca_map_.getConcreteMappedID(id); } @@ -70,7 +65,8 @@ void IndexReferenceReplay::handle(Split* split) { } // Replay the provided split operation and add it to the reference DAG - new Split( + IrBuilder::create( + split->container(), ref_outer, ref_inner, ref_in, @@ -101,7 +97,7 @@ void IndexReferenceReplay::handle(Merge* merge) { } // Replay the provided merge operation and add it to the reference DAG - new Merge(ref_out, ref_outer, ref_inner); + IrBuilder::create(merge->container(), ref_out, ref_outer, ref_inner); // Mark producers and consumers ref_id_consumed_.emplace(ref_outer); @@ -149,7 +145,7 @@ TensorDomain* IndexReferenceReplay::computeReplay() { loop_structure_.begin(), loop_structure_.end(), std::back_inserter(domain_ids), - [this](kir::ForLoop* fl) { return toFusionID(fl->iter_domain()); }); + [](kir::ForLoop* fl) { return fl->iter_domain(); }); // IterVisitor based traversals don't work because we don't have all outputs. // backward traversal's traverseFrom(domain_ids) will throw "Invalid backward @@ -194,7 +190,7 @@ TensorDomain* IndexReferenceReplay::computeReplay() { // Construct a tensor that's representitive of the replayed loop structure. std::vector loops_replayed_domain; for (auto loop : loop_structure_) { - auto loop_id = toFusionID(loop->iter_domain()); + auto loop_id = loop->iter_domain(); // Map to loops with the loop map, but make sure the replayed id is actually // a leaf in the replay. auto ref_id_it = std::find_if( @@ -222,7 +218,7 @@ TensorDomain* IndexReferenceReplay::computeReplay() { loops_replayed_domain.begin(), loops_replayed_domain.end(), [](IterDomain* id) { return id->definition() != nullptr; })) { - auto domain = new TensorDomain( + auto domain = IrBuilder::create( // If there was no replay only return a domain with a root domain. loops_replayed_domain); return domain; @@ -257,8 +253,9 @@ TensorDomain* IndexReferenceReplay::computeReplay() { } // Create and return the reference. - auto domain = new TensorDomain( - {root_domain_ids.begin(), root_domain_ids.end()}, + auto domain = IrBuilder::create( + std::vector( + root_domain_ids.begin(), root_domain_ids.end()), loops_replayed_domain); return domain; } @@ -266,26 +263,30 @@ TensorDomain* IndexReferenceReplay::computeReplay() { IndexCompute getReferenceIndexing( const std::vector& loop_structure, - TensorDomain* reference_tensor) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - + TensorDomain* reference_tensor, + kir::ForLoop* double_buffer_loop) { // Create a simple index mapping from loop iter domains to their local index. // This is only applicable to global memory buffers. - std::unordered_map initial_index_map; + std::unordered_map initial_index_map; TORCH_INTERNAL_ASSERT(loop_structure.size() <= reference_tensor->nDims()); int magic_zero_loop = -1; for (const auto loop_i : c10::irange(loop_structure.size())) { auto ref_axis = reference_tensor->axis(loop_i); - auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as(); auto loop = loop_structure[loop_i]; auto ind = loop->index(); - ; - initial_index_map[kir_ref_axis] = ind; + initial_index_map[ref_axis] = ind; if (loop->vectorize()) { - initial_index_map[kir_ref_axis] = ir_builder.create(0); + initial_index_map[ref_axis] = GpuLower::current()->kernel()->zeroVal(); + } else if (double_buffer_loop == loop) { + // This version of getReferenceIndexing is only used for + // indexing global tensors. When indexing global producers, the + // index for a double buffered loop needs to be incremented. The + // parameter double_buffer_loop should be nullptr when indexing + // global consumers tensors. + initial_index_map[ref_axis] = + IrBuilder::addExpr(ind, GpuLower::current()->kernel()->oneVal()); } if (Index::protectWithMagicZero(loop, ref_axis, ind)) { @@ -295,10 +296,9 @@ IndexCompute getReferenceIndexing( // Add magic zero to a fairly inner most index if (magic_zero_loop >= 0) { - auto ref_id = gpu_lower->lowerValue(reference_tensor->axis(magic_zero_loop)) - ->as(); - initial_index_map[ref_id] = ir_builder.addExpr( - initial_index_map[ref_id], ir_builder.magicZeroVal()); + auto ref_id = reference_tensor->axis(magic_zero_loop); + initial_index_map[ref_id] = IrBuilder::addExpr( + initial_index_map[ref_id], FusionGuard::getCurFusion()->magicZeroVal()); } // Send to the other version of reference indexing that directly takes the @@ -310,19 +310,17 @@ IndexCompute getReferenceIndexing( IndexCompute getReferenceIndexing( const std::vector& loop_structure, TensorDomain* reference_tensor, - std::unordered_map index_map, - std::unordered_set zero_domains, + std::unordered_map index_map, + std::unordered_set zero_domains, std::unordered_set preferred_paths, - std::unordered_map halo_extent_map) { - auto gpu_lower = GpuLower::current(); - + std::unordered_map halo_extent_map) { // I thought this might be necesasry, but turns out it's not. I think it's // because of the root ordering above, however leaving it in case we find // out it is necessary in some cases. At the time of commiting, cuda-memcheck // passed without this. // - // std::unordered_map reference_extent_map; for (auto loop : loop_structure) { + // std::unordered_map reference_extent_map; for (auto loop : loop_structure) { // // If there's a broadcast merged in the for loop ID we want to track its // // extent // auto inputs = InputsOf::outputs( @@ -342,16 +340,6 @@ IndexCompute getReferenceIndexing( // } // } - // Convert to preferred_path to kir::IterDomain for IndexCompute - std::unordered_set kir_preferred_path; - std::transform( - preferred_paths.begin(), - preferred_paths.end(), - std::inserter(kir_preferred_path, kir_preferred_path.begin()), - [&gpu_lower](IterDomain* id) { - return gpu_lower->lowerValue(id)->as(); - }); - IndexCompute compute( reference_tensor, index_map, // NOLINT @@ -359,9 +347,9 @@ IndexCompute getReferenceIndexing( // in this function {}, zero_domains, - std::unordered_set(), + std::unordered_set(), reference_tensor->contiguity(), - kir_preferred_path, + preferred_paths, halo_extent_map); compute.run(); diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.h b/torch/csrc/jit/codegen/cuda/index_reference_replay.h index c4626213e76..fcb8e1f94e8 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.h +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -34,10 +34,6 @@ class IndexReferenceReplay : public OptInDispatch { // Make a new id for the reference replay based on the provided id IterDomain* idCopy(IterDomain* id); - // Use the compute at map to get the fusion IterDomain from the - // kir::IterDomain - IterDomain* toFusionID(kir::IterDomain* kir_id); - // Return the concrete entry of the non-reference id IterDomain* toConcrete(IterDomain* id); @@ -87,16 +83,17 @@ class IndexReferenceReplay : public OptInDispatch { IndexCompute getReferenceIndexing( const std::vector& loop_structure, TensorDomain* reference_domain, - std::unordered_map index_map, - std::unordered_set zero_domains, + std::unordered_map index_map, + std::unordered_set zero_domains, std::unordered_set preferred_path, - std::unordered_map halo_extent_map = {}); + std::unordered_map halo_extent_map = {}); // Short cut for global TVs. Index into the reference based on all loop indicies // in the loop structure. IndexCompute getReferenceIndexing( const std::vector& loop_structure, - TensorDomain* reference_domain); + TensorDomain* reference_domain, + kir::ForLoop* double_buffer_loop = nullptr); // When indexing there are sometimes an option to propagate an index down // multiple paths. This will return the IterDomains in the history of the diff --git a/torch/csrc/jit/codegen/cuda/instrumentation.cpp b/torch/csrc/jit/codegen/cuda/instrumentation.cpp index 52e16b3a7af..d227df0ab26 100644 --- a/torch/csrc/jit/codegen/cuda/instrumentation.cpp +++ b/torch/csrc/jit/codegen/cuda/instrumentation.cpp @@ -1,6 +1,6 @@ #include -#include +#include #ifdef _WIN32 #include diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index bd54d30811d..d21004ae154 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -15,13 +15,15 @@ C10_DEFINE_bool( C10_DEFINE_bool( torch_jit_nvfuser_horizontal_fusion, true, - "enable single node fusion for nvfuser"); + "enable horizontal fusion for nvfuser"); namespace torch { namespace jit { namespace fuser { namespace cuda { +static std::atomic cuda_fusion_guard_mode{true}; + bool getSingletonFusion() { return FLAGS_torch_jit_nvfuser_singleton_fusion; } @@ -42,8 +44,6 @@ bool setHorizontalFusion(bool value) { return old_value; } -static std::atomic cuda_fusion_guard_mode{true}; - std::atomic& getCudaFusionGuardMode() { return cuda_fusion_guard_mode; } @@ -329,6 +329,220 @@ RegisterOperators reg_guard({ aliasAnalysisFromSchema()), }); +// Infer dynamic axis (-1) in view_sizes given tensor_sizes +bool inferViewShape( + c10::List tensor_sizes, + c10::List view_sizes) { + int64_t dynamic_index = -1; + size_t view_size_num_elements = 1; + for (size_t idx = 0; idx < view_sizes.size(); ++idx) { + if (view_sizes[idx] == -1) { + TORCH_INTERNAL_ASSERT( + dynamic_index == -1, "Only one dimension can by inferred.") + dynamic_index = idx; + } else { + TORCH_INTERNAL_ASSERT(view_sizes[idx] > 0); + view_size_num_elements *= view_sizes[idx]; + } + } + const size_t kNumElements = std::accumulate( + tensor_sizes.begin(), tensor_sizes.end(), 1, std::multiplies<>()); + + if (kNumElements % view_size_num_elements != 0) { + return false; + } + + if (dynamic_index != -1) { + view_sizes[dynamic_index] = kNumElements / view_size_num_elements; + } + + return true; +} + +//! [ Note -- type guard logic in CudaFusionViewGuard ] +//! +//! CudaFusionViewGuard is used to guard input tensors to a `CudaFusionGroup` +//! that contains view operations, so that we would not feed inputs that +//! violate the graph defined in `GraphCache`. +//! +//! output = view(self, view-sizes) +//! +//! View Guard Inputs: +//! 1. self tensor_sizes - dynamic size List[Int] +//! 2. view_sizes - profile_ivalue List[Int] +//! 3. tensor_constraint - Constant List[Int] +//! 4. view_sizes_constraint - Constant List[Int] +//! +//! Things that we check: +//! 1. The #dimensions are the same for self tensor and its constraint +//! 2. The #dimensions are the same for view-sizes and its constraint +//! 3. Self tensor does not violate its constraint +//! a. Queue unrestricted sizes +//! b. Calculate #elements in self tensor +//! 4. view-sizes does not violate its constraint +//! a. Pop unrestricted sizes from queue +//! b. Calculate #elements in view-sizes +//! 5. The #elements is the same for self tensor and view-sizes +//! +//! Constraints: +//! A restricted axis creates a graph constraint, so its sizes is static. +//! An unrestricted axis is allowed to have a dynamic size, if it is consistent +//! between self tensor and view-sizes. It is marked with -1 in the constraint. +//! Only iterDomains with the Keep transform are dynamic. All other transforms +//! create a static constraint. +//! +bool checkViewGuard( + c10::List tensor_sizes, + c10::List view_sizes, + c10::List tensor_constraint, + c10::List view_sizes_constraint) { + // 1: Num Dimensions Check + if (tensor_constraint.size() != tensor_sizes.size() || + view_sizes_constraint.size() != view_sizes.size()) { + return false; + } + + // If axis allows dynamic sizes, then add tensor size to this queue. + // For dynamic axes in view_sizes, check that it is consistent with + // the corresponding tensor size. + std::queue dynamic_axis_queue; + + // 2. Tensor Static Check + int64_t tensor_size_product = 1; + for (const auto idx : c10::irange(tensor_sizes.size())) { + if (tensor_constraint[idx] == -1) { + dynamic_axis_queue.push(tensor_sizes[idx]); + } else if (tensor_constraint[idx] != tensor_sizes[idx]) { + return false; + } + tensor_size_product *= tensor_sizes[idx]; + } + + // 3. View-Sizes Static Check + int64_t view_size_product = 1; + for (const auto idx : c10::irange(view_sizes.size())) { + auto dynamic_size = (view_sizes_constraint[idx] == -1) + ? dynamic_axis_queue.front() + : view_sizes_constraint[idx]; + if (dynamic_size != view_sizes[idx]) { + return false; + } + view_size_product *= dynamic_size; + if (view_sizes_constraint[idx] == -1) { + dynamic_axis_queue.pop(); + } + } + + // 4. Check view invariant + // The number of elements in the input and output tensors are the same. + return tensor_size_product == view_size_product; +} + +//! +//! CudaFusionViewGuard Example Graph: +//! +//! graph(%self : __torch__.BiasViewRelu, +//! %inputs.1 : Tensor): +//! %2 : int = prim::Constant[value=-1]() # dynamic_bvg.py:50:40 +//! %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25 +//! %4 : NoneType = prim::Constant() +//! %5 : int[] = prim::Constant[value=[2, 3]]() +//! %6 : int[] = aten::size(%inputs.1) # dynamic_bvg.py:50:25 +//! %7 : int[] = aten::slice(%6, %4, %2, %3) # dynamic_bvg.py:50:25 +//! %view_shape.1 : int[] = aten::add(%7, %5) # dynamic_bvg.py:50:25 +//! %bias : Tensor = prim::GetAttr[name="bias"](%self) +//! %10 : int[] = aten::size(%bias) +//! %11 : int[] = prim::BroadcastSizes(%6, %10) +//! %12 : bool = prim::CudaFusionGuard[types=[...]](%inputs.1, %bias) +//! %13 : int[] = prim::Constant[value=[-1, -1, -1, 6]]() +//! %14 : int[] = prim::Constant[value=[-1, -1, -1, 2, 3]]() +//! %15 : bool = prim::CudaFusionViewGuard(%11, %view_shape.1, %13, %14) +//! %16 : bool[] = prim::ListConstruct(%15, %12) +//! %17 : bool = aten::all(%16) +//! %18 : Tensor = prim::If(%17) +//! block0(): +//! %19 : Tensor = prim::CudaFusionGroup_0[cache_id=0](%inputs.1, %bias) +//! -> (%19) +//! block1(): +//! %20 : Function = prim::Constant[name="fallback_fn", fallback=1]() +//! %21 : (...) = prim::CallFunction(%20, %inputs.1, %bias, %view_shape.1) +//! %22 : Float(...) = prim::TupleUnpack(%21) +//! -> (%22) +//! return (%18) +//! with prim::CudaFusionGroup_0 = graph(%0 : Float(...), +//! %1 : Float(...)): +//! %2 : int[] = prim::Constant[value=[2, 3, 4, 2, 3]]() +//! %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25 +//! %o.1 : Float(...) = aten::add(%0, %1, %3) # dynamic_bvg.py:51:16 +//! %5 : Float(...) = prim::view_copy(%o.1, %2) +//! %6 : Float(...) = aten::relu(%5) # dynamic_bvg.py:53:19 +//! return (%6) +//! +RegisterOperators view_guard({ + Operator( + "prim::CudaFusionViewGuard(...) -> bool", + // prim::CudaFusionViewGuard returns a fresh Boolean type without + // aliasing. if we would ever return refined tensor, which would change + // aliasing analysis, we should update aliasdb pass. + [](const Node* node) -> Operation { + return [](Stack& stack) { + // view_sizes_constraint - Constant List[Int] + at::ArrayRef inputs = last(stack, 4); + + // tensor_sizes is the runtime size for the self tensor + // tensor_sizes - dynamic size List[Int] + TORCH_INTERNAL_ASSERT( + inputs[0].isIntList(), "tensor_sizes needs to be Int List"); + auto tensor_sizes = inputs[0].toIntList(); + + // profiled_view_sizes is the runtime view size + // profiled_view_sizes - profile_ivalue List[Int] + TORCH_INTERNAL_ASSERT( + inputs[1].isIntList(), + "profiled_view_sizes needs to be Int list"); + auto profiled_view_sizes = inputs[1].toIntList(); + + // tensor_constraint is a constant List[Int] + // used to guard tensor_sizes + TORCH_INTERNAL_ASSERT( + inputs[2].isIntList(), + "tensor constraint needs to be Int List"); + auto tensor_constraint = inputs[2].toIntList(); + + // view_sizes_constraint is a constant List[Int] + // used to guard profiled_view_sizes + TORCH_INTERNAL_ASSERT( + inputs[3].isIntList(), + "view_sizes constraint needs to be Int List"); + auto view_sizes_constraint = inputs[3].toIntList(); + + // Drop after gather all input arguments + // If an argument is moved, it is destroyed when dropped from stack + drop(stack, 4); + + auto status = inferViewShape(tensor_sizes, profiled_view_sizes); + if (!status) { + push(stack, IValue(false)); + return; + } + + if (!fuser::cuda::getCudaFusionGuardMode()) { + push(stack, IValue(true)); + return; + } + + auto guard_status = checkViewGuard( + tensor_sizes, + profiled_view_sizes, + tensor_constraint, + view_sizes_constraint); + push(stack, IValue(guard_status)); + return; + }; + }, + aliasAnalysisFromSchema()), +}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) RegisterOperators reg_add_optional({ Operator( @@ -346,6 +560,160 @@ RegisterOperators reg_add_optional({ }, aliasAnalysisFromSchema()), }); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_view_copy({ + Operator( + "prim::view_copy(Tensor self, int[] size) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "view_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, size; + pop(stack, self, size); + push(stack, at::native::view(self.toTensor(), size.toIntVector())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_reshape_copy({ + Operator( + "prim::reshape_copy(Tensor self, int[] shape) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "reshape_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, shape; + pop(stack, self, shape); + push( + stack, + at::native::reshape(self.toTensor(), shape.toIntVector())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_squeeze_copy({ + Operator( + "prim::squeeze_copy(Tensor self) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "squeeze_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self; + pop(stack, self); + push(stack, at::squeeze(self.toTensor())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_squeeze_dim_copy({ + Operator( + "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "squeeze_dim_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, dim; + pop(stack, self, dim); + push(stack, at::squeeze(self.toTensor(), dim.toInt())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_unsqueeze_copy({ + Operator( + "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "unsqueeze_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, dim; + pop(stack, self, dim); + push(stack, at::unsqueeze(self.toTensor(), dim.toInt())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_infer_unsqueeze_size({ + Operator( + "prim::infer_unsqueeze_size(int[] a, int dim) -> int[]", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto dim = pop(stack).toInt(); + auto size = pop(stack).toIntVector(); + if (dim < 0) { + dim = dim + 1 + size.size(); + } + auto it = size.begin() + dim; + size.insert(it, 1); + push(stack, IValue(size)); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_infer_squeeze_dim_size({ + Operator( + "prim::infer_squeeze_size(int[] a, int dim) -> int[]", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto dim = pop(stack).toInt(); + auto size = pop(stack).toIntVector(); + if (dim < 0) { + dim = dim + size.size(); + } + auto it = size.begin() + dim; + if (*it == 1) { + size.erase(it); + } + push(stack, IValue(size)); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_infer_squeeze_size({ + Operator( + "prim::infer_squeeze_size.dim(int[] a) -> int[]", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto size = pop(stack).toIntVector(); + + for (auto it = size.begin(); it != size.end(); it++) { + if (*it == 1) { + auto pre = it - 1; + size.erase(it); + it = pre; + } + } + push(stack, IValue(size)); + }; + }, + aliasAnalysisFromSchema()), +}); + } // namespace } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/interface.h b/torch/csrc/jit/codegen/cuda/interface.h index 1ab9e6d8008..8afa854ea5c 100644 --- a/torch/csrc/jit/codegen/cuda/interface.h +++ b/torch/csrc/jit/codegen/cuda/interface.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index cf3d9c7a8c7..6a094c104df 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -1,8 +1,12 @@ #include #include #include +#include #include #include +#include +#include +#include #include #include @@ -20,16 +24,20 @@ namespace jit { namespace fuser { namespace cuda { +Statement::Statement(IrBuilderPasskey passkey) { + ir_container_ = passkey.ir_container_; +} + Statement::Statement(const Statement* src, IrCloner* ir_cloner) { - // IRCloner when cloning to a new fusion will copy the names of the original - // fusion. If we're cloning into the same fusion, we let Val and Expr get - // their names as usual by registering with the current fusion in their - // constructors, so don't overwrite that here. - if (src->fusion() != ir_cloner->fusion()) { - name_ = src->name_; - } - fusion_ = ir_cloner->fusion(); - ir_cloner->registerClone(src, this); + ir_container_ = ir_cloner->container(); +} + +void Statement::setName(IrContainerPasskey, StmtNameType name) { + name_ = name; +} + +void Statement::setName(IrBuilderPasskey, StmtNameType name) { + name_ = name; } Val* Statement::asVal() { @@ -42,24 +50,37 @@ Expr* Statement::asExpr() { return this->as(); } -void Statement::print() const { - IrPrinter ir_printer(std::cout); +std::string Statement::toString() const { + std::stringstream ss; + IrPrinter ir_printer(ss); ir_printer.handle(this); - std::cout << std::endl; + return ss.str(); } -// When we create a Val we immediately register them with the active fusion. -Val::Val(ValType _vtype, DataType _dtype, bool register_val) - : vtype_(_vtype), dtype_(_dtype) { - Fusion* fusion = FusionGuard::getCurFusion(); - TORCH_CHECK( - fusion != nullptr, "No active fusion group found when creating a Val."); - fusion_ = fusion; - if (register_val) { - name_ = fusion_->registerVal(this); - } +std::string Statement::toInlineString() const { + std::stringstream ss; + IrPrinter ir_printer(ss); + ir_printer.print_inline(this); + return ss.str(); } +Fusion* Statement::fusion() const { + TORCH_INTERNAL_ASSERT( + ir_container_->isA(), "Statement does not belong to a fusion."); + return ir_container_->as(); +} + +kir::Kernel* Statement::kernel() const { + TORCH_INTERNAL_ASSERT( + ir_container_->isA(), + "Statement does not belong to a kernel."); + return ir_container_->as(); +} + +// When we create a Val we immediately register them with the active fusion. +Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype) + : Statement(passkey), vtype_(_vtype), dtype_(_dtype) {} + // NOTE: we don't clone the definition_ and uses_ here // since they may introduce cloning cycles. Instead, we copy // the original pointers and we'll fix them up later part of the @@ -71,12 +92,7 @@ Val::Val(const Val* src, IrCloner* ir_cloner) vtype_(src->vtype_), dtype_(src->dtype_), is_fusion_input_(src->is_fusion_input_), - is_fusion_output_(src->is_fusion_output_) { - // If we're "cloning" into the same fusion, register with the fusion - if (src->fusion() == ir_cloner->fusion()) { - name_ = src->fusion()->registerVal(this); - } -} + is_fusion_output_(src->is_fusion_output_) {} const std::vector& Val::uses() const { if (vtype_ == ValType::TensorView) { @@ -92,33 +108,33 @@ namespace { // Traverse definition of all values involved in constructing the provided val. // Check if all values involved are constant values, meaning the provided // val is also a constant value. -class ConstCheck : OptOutConstDispatch { +class ConstCheck : private OptOutConstDispatch { private: bool is_const_ = true; - void handle(const Bool* b) override { + void handle(const Bool* b) final { is_const_ = is_const_ && b->isConst(); } - void handle(const Double* d) override { + void handle(const Double* d) final { is_const_ = is_const_ && d->isConst(); } - void handle(const Int* i) override { + void handle(const Int* i) final { is_const_ = is_const_ && i->isConst(); } - void handle(const NamedScalar* ns) override { + void handle(const NamedScalar* ns) final { is_const_ = is_const_ && false; } - void handle(const Expr* expr) override { + void handle(const Expr* expr) final { for (auto inp : expr->inputs()) { handle(inp); } } - void handle(const Val* val) override { + void handle(const Val* val) final { if (val->definition() != nullptr) { handle(val->definition()); } else { @@ -137,15 +153,18 @@ class ConstCheck : OptOutConstDispatch { } // namespace bool Val::isConstScalar() const { - if (!isScalar()) + if (!isScalar()) { return false; + } return ConstCheck::isConst(this); } c10::optional Val::getInt() const { if (isConstScalar() && isAnInt()) { if (this->getValType() == ValType::Scalar) { - return this->as()->value(); + if (this->isA()) { + return this->as()->value(); + } } } return c10::optional(); @@ -169,7 +188,7 @@ c10::optional Val::getDataType() const { bool Val::isProducerOf(const Val* other) const { TORCH_INTERNAL_ASSERT(other != nullptr); - TORCH_INTERNAL_ASSERT(fusion() == other->fusion()); + TORCH_INTERNAL_ASSERT(container() == other->container()); if (definition() == nullptr) { return false; @@ -186,23 +205,14 @@ bool Val::isConsumerOf(const Val* other) const { // We don't register with the active fusion in Expr as this needs to be done // after inputs and outputs are registered with the Expr -Expr::Expr(ExprType type) : type_{type} { - Fusion* fusion = FusionGuard::getCurFusion(); - if (fusion == nullptr) - TORCH_CHECK(false, "No active fusion group found when creating an Expr."); - fusion_ = fusion; -} +Expr::Expr(IrBuilderPasskey passkey, ExprType etype) + : Statement(passkey), etype_{etype} {} Expr::Expr(const Expr* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), - type_(src->type_), + etype_(src->etype_), inputs_(ir_cloner->clone(src->inputs_)), - outputs_(ir_cloner->clone(src->outputs_)) { - // If we're "cloning" into the same fusion, register with the fusion - if (src->fusion() == ir_cloner->fusion()) { - name_ = src->fusion()->registerExpr(this); - } -} + outputs_(ir_cloner->clone(src->outputs_)) {} bool Expr::sameAs(const Statement* other) const { if (this == other) { @@ -227,6 +237,30 @@ bool Expr::sameAs(const Statement* other) const { return true; } +kir::Predicate* Expr::predicate() const { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + return predicate_; +} + +void Expr::setPredicate(kir::Predicate* predicate) { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + predicate_ = predicate; +} + +kir::Predicate* Expr::writePredicate() const { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + return write_predicate_; +} + +void Expr::setWritePredicate(kir::Predicate* write_predicate) { + TORCH_INTERNAL_ASSERT( + container()->isA(), "Function invalid for fusion."); + write_predicate_ = write_predicate; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 2e0fa0885bd..1b8444fae46 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #include #include @@ -35,6 +35,8 @@ namespace jit { namespace fuser { namespace cuda { +using ValueId = int32_t; + using StmtNameType = unsigned int; constexpr StmtNameType kInvalidStmName = @@ -48,6 +50,22 @@ class UnaryOp; class BinaryOp; class IterDomain; class IrCloner; +class IrContainer; +class IrBuilderPasskey; +class IrContainerPasskey; + +namespace kir { +class Kernel; +class Predicate; +} // namespace kir + +// Passkey for container to register names with statements +class ExprPasskey { + friend class Expr; + + private: + explicit ExprPasskey() {} +}; TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; @@ -60,12 +78,12 @@ TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; //! is also important for the design to have a dispatch system for a Statment. //! Basically beinng able to succienctly traverse down the inhereitance stack of //! a Statment at runtime. This is currently implemented in dispatch.h -//! class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { friend void swap(Fusion&, Fusion&) noexcept; + friend void swap(IrContainer& a, IrContainer& b) noexcept; public: - Statement() = default; + Statement() = delete; // Cloning constructor Statement(const Statement* src, IrCloner* ir_cloner); @@ -78,7 +96,7 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { static void constDispatch(T handler, const Statement* const); template - static Statement* mutatorDispatch(T mutator, Statement*); + static void mutatorDispatch(T mutator, Statement*); // Accessor functions to types. Vals always have a DataType, Exprs never do virtual c10::optional getValType() const { @@ -106,8 +124,14 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { Expr* asExpr(); // Return the fusion this statement belongs to - Fusion* fusion() const { - return fusion_; + Fusion* fusion() const; + + // Return the kernel this statement belongs to + kir::Kernel* kernel() const; + + // Return the container this statement belongs to + IrContainer* container() const { + return ir_container_; } // Return the int that represents its name @@ -115,6 +139,13 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { return name_; } + // Set the statements' name. Typically the container will set the name, + // however if we're dealing with cloning, IrBuilder will set the name, this + // maybe should be from IrCloner, however I didn't want to add another + // passkey. + void setName(IrContainerPasskey, StmtNameType name); + void setName(IrBuilderPasskey, StmtNameType name); + virtual bool sameType(const Statement* const other) { if (isVal() && other->isVal()) return getValType().value() == other->getValType().value(); @@ -129,13 +160,17 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { return this == other; } - void print() const; + std::string toString() const; + std::string toInlineString() const; protected: + Statement(IrBuilderPasskey); + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) StmtNameType name_ = kInvalidStmName; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Fusion* fusion_ = nullptr; + IrContainer* ir_container_ = nullptr; }; //! A Val represents a "value." These are objects, like tensors, scalars, and @@ -169,34 +204,43 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { //! class TORCH_CUDA_CU_API Val : public Statement { public: - // We may not want to register this value during Val's constructor. The reason - // for this is that if we register the val, then in a derived constructor try - // to throw, fusion's destructor will get called, but the pointer to this Val - // will be invalid. When fusion tries to delete this value it will cause a seg - // fault, instead of showing the thrown error. explicit Val( + IrBuilderPasskey, ValType _vtype, - DataType _dtype = DataType::Null, - bool register_val = true); + DataType _dtype = DataType::Null); Val(const Val* src, IrCloner* ir_cloner); - // TODO: why is this optional? - // + // Dispatch functions, definitions in dispatch.cpp + template + static void dispatch(T handler, Val*); + + template + static void constDispatch(T handler, const Val* const); + + template + static void mutatorDispatch(T mutator, Val*); + c10::optional getValType() const override { return vtype_; } + ValType vtype() const { + return vtype_; + } + + DataType dtype() const { + return dtype_; + } + // Throws if no DataType is found. Vals must have a DataType - // - // TODO: why is this optional? - // c10::optional getDataType() const override; bool isScalar() const { return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar; } + // Returns if all dependencies are constant scalars bool isConstScalar() const; bool isAnInt() const { @@ -205,6 +249,11 @@ class TORCH_CUDA_CU_API Val : public Statement { c10::optional getInt() const; + // Returns if no dependencies and is a constant scalar. + virtual bool isConst() const { + return false; + } + bool isZeroInt() const; bool isOneInt() const; @@ -254,15 +303,11 @@ class TORCH_CUDA_CU_API Val : public Statement { return evaluator_index_; } - // Dispatch functions, definitions in dispatch.cpp - template - static void dispatch(T handler, Val*); - - template - static void constDispatch(T handler, const Val* const); - - template - static Statement* mutatorDispatch(T mutator, Val*); + // Following is managed by Fusion (or kirIrBuilder) and can change. + // TODO: Protect with a passkey. + void setDefinition(Expr* expr) { + definition_ = expr; + } protected: friend Fusion; @@ -272,19 +317,17 @@ class TORCH_CUDA_CU_API Val : public Statement { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const DataType dtype_; - // Following is managed by Fusion and can change. - void setDefinition(Expr* expr) { - definition_ = expr; - } - + // TODO: Add fusion passkey for this void setIsFusionInput(bool is_fusion_input) { is_fusion_input_ = is_fusion_input; } + // TODO: Add fusion passkey for this void setIsFusionOutput(bool is_fusion_output) { is_fusion_output_ = is_fusion_output; } + // TODO: Add fusion or container passkey for this void setUses(const std::vector& uses) { uses_ = uses; } @@ -297,6 +340,7 @@ class TORCH_CUDA_CU_API Val : public Statement { Expr* definition_ = nullptr; std::vector uses_; + // Expr evaluator idx; int evaluator_index_ = -1; }; @@ -342,15 +386,16 @@ class TORCH_CUDA_CU_API Val : public Statement { //! class TORCH_CUDA_CU_API Expr : public Statement { public: - explicit Expr(ExprType type); + explicit Expr(IrBuilderPasskey, ExprType type); + Expr(const Expr* src, IrCloner* ir_cloner); c10::optional getExprType() const override { - return type_; + return etype_; } - ExprType type() const { - return type_; + ExprType etype() const { + return etype_; } bool sameAs(const Statement* other) const override; @@ -380,23 +425,46 @@ class TORCH_CUDA_CU_API Expr : public Statement { static void constDispatch(T handler, const Expr* const); template - static Statement* mutatorDispatch(T mutator, Expr*); + static void mutatorDispatch(T mutator, Expr*); + + // TODO: Protect based on being in kernel container + kir::Predicate* predicate() const; + + // TODO: Protect based on being in kernel container + void setPredicate(kir::Predicate* predicate); + + // TODO: Protect based on being in kernel container + kir::Predicate* writePredicate() const; + + // TODO: Protect based on being in kernel container + void setWritePredicate(kir::Predicate* write_predicate); protected: + // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); inputs_.push_back(input); } + // TODO: Add Fusion passkey void addOutput(Val* output) { TORCH_INTERNAL_ASSERT(output != nullptr); outputs_.push_back(output); } + ExprPasskey exprPasskey() { + return ExprPasskey(); + } + private: - ExprType type_ = ExprType::Invalid; + ExprType etype_ = ExprType::Invalid; std::vector inputs_; std::vector outputs_; + + kir::Predicate* predicate_ = nullptr; + + // Only used for reduction-related expressions + kir::Predicate* write_predicate_ = nullptr; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp similarity index 50% rename from torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp rename to torch/csrc/jit/codegen/cuda/ir_builder.cpp index ce3e17d74d2..17a4e59cfb6 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -1,35 +1,97 @@ -#include +#include +#include +#include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { -namespace kir { + +//! Clone an IR node, forwarding the arguments to the IrCloner constructor. +template +T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { + TORCH_INTERNAL_ASSERT( + ir_cloner != nullptr, + "Cannot use create when a cloner object is set. Use clone."); + + TORCH_INTERNAL_ASSERT( + ir_cloner->container() != nullptr, + "Cloner doesn't have a valid container to store cloned object."); + + T* dest = new T(src, ir_cloner); + const Statement* src_stmt = dynamic_cast(src); + Statement* dest_stmt = dynamic_cast(dest); + + auto dest_container = ir_cloner->container(); + auto src_container = src_stmt->container(); + + dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + + if (src_container != dest_container) { + dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); + } + + ir_cloner->registerClone(src_stmt, dest_stmt); + + return dest; +} + +#define IR_BUILDER_INSTANTIATE(T) \ + template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner); + +// Vals +IR_BUILDER_INSTANTIATE(IterDomain) +IR_BUILDER_INSTANTIATE(TensorDomain) +IR_BUILDER_INSTANTIATE(TensorView) +IR_BUILDER_INSTANTIATE(Bool) +IR_BUILDER_INSTANTIATE(Double) +IR_BUILDER_INSTANTIATE(Int) +IR_BUILDER_INSTANTIATE(NamedScalar) + +// Exprs +IR_BUILDER_INSTANTIATE(Split) +IR_BUILDER_INSTANTIATE(Merge) +IR_BUILDER_INSTANTIATE(TransposeOp) +IR_BUILDER_INSTANTIATE(ShiftOp) +IR_BUILDER_INSTANTIATE(GatherOp) +IR_BUILDER_INSTANTIATE(ViewOp) +IR_BUILDER_INSTANTIATE(UnaryOp) +IR_BUILDER_INSTANTIATE(BinaryOp) +IR_BUILDER_INSTANTIATE(TernaryOp) +IR_BUILDER_INSTANTIATE(ReductionOp) +IR_BUILDER_INSTANTIATE(WelfordOp) +IR_BUILDER_INSTANTIATE(BroadcastOp) Val* IrBuilder::newResult(DataType dtype) { switch (dtype) { case DataType::Bool: - return create(c10::nullopt); + return IrBuilder::create(c10::nullopt); case DataType::Double: - return create(c10::nullopt); + return IrBuilder::create(c10::nullopt); case DataType::Int: - return create(c10::nullopt); + return IrBuilder::create(c10::nullopt); default: TORCH_CHECK(false, "Unexpected data type"); } } Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { - TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types"); + TORCH_CHECK( + lhs->dtype() == rhs->dtype(), + "Incompatible operand types: ", + lhs->dtype(), + " and ", + rhs->dtype()); auto result = newResult(lhs->dtype()); - create(op_type, result, lhs, rhs); + IrBuilder::create(op_type, result, lhs, rhs); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return result; } Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { - auto result = create(c10::nullopt); - create(op_type, result, lhs, rhs); + auto result = IrBuilder::create(c10::nullopt); + IrBuilder::create(op_type, result, lhs, rhs); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return result; } @@ -37,37 +99,37 @@ Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) { Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) { TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types"); auto result = newResult(lhs->dtype()); - create(TernaryOpType::Where, result, pred, lhs, rhs); + IrBuilder::create(TernaryOpType::Where, result, pred, lhs, rhs); return result; } Val* IrBuilder::negExpr(Val* val) { auto result = newResult(val->dtype()); - create(UnaryOpType::Neg, result, val); + IrBuilder::create(UnaryOpType::Neg, result, val); return result; } Val* IrBuilder::notExpr(Val* val) { auto result = newResult(val->dtype()); - create(UnaryOpType::Not, result, val); + IrBuilder::create(UnaryOpType::Not, result, val); return result; } Val* IrBuilder::setExpr(Val* val) { auto result = newResult(val->dtype()); - create(UnaryOpType::Set, result, val); + IrBuilder::create(UnaryOpType::Set, result, val); return result; } Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) { - auto result = create(name, val->dtype()); - create(UnaryOpType::Set, result, val); + auto result = IrBuilder::create(name, val->dtype()); + IrBuilder::create(UnaryOpType::Set, result, val); return result; } Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) { - auto result = create(name, DataType::Int); - create(UnaryOpType::Address, result, val); + auto result = IrBuilder::create(name, DataType::Int); + IrBuilder::create(UnaryOpType::Address, result, val); return result; } @@ -127,45 +189,10 @@ Val* IrBuilder::minExpr(Val* lhs, Val* rhs) { return newArithmeticExpr(BinaryOpType::Min, lhs, rhs); } -Int* IrBuilder::zeroVal() { - if (zero_ == nullptr) { - zero_ = create(0); - } - return zero_; -} - -Int* IrBuilder::oneVal() { - if (one_ == nullptr) { - one_ = create(1); - } - return one_; -} - -Bool* IrBuilder::falseVal() { - if (false_ == nullptr) { - false_ = create(false); - } - return false_; -} - -Bool* IrBuilder::trueVal() { - if (true_ == nullptr) { - true_ = create(true); - } - return true_; -} - -NamedScalar* IrBuilder::magicZeroVal() { - if (magic_zero_ == nullptr) { - magic_zero_ = create(kMagicZeroName, DataType::Int); - } - return magic_zero_; -} - Val* SimplifyingIrBuilder::negExpr(Val* val) { - if (auto int_val = dynamic_cast(val)) { + if (auto int_val = dynamic_cast(val)) { if (int_val->isConst()) { - return create(-int_val->value().value()); + return IrBuilder::create(-int_val->value().value()); } } return IrBuilder::negExpr(val); @@ -175,9 +202,9 @@ Val* SimplifyingIrBuilder::notExpr(Val* val) { if (auto bool_val = dynamic_cast(val)) { if (bool_val->isConst()) { if (bool_val->value().value()) { - return falseVal(); + return FusionGuard::getCurFusion()->falseVal(); } else { - return trueVal(); + return FusionGuard::getCurFusion()->trueVal(); } } } @@ -188,13 +215,13 @@ Val* SimplifyingIrBuilder::addExpr(Int* lhs, Int::ScalarType rhs) { if (rhs == 0) { return lhs; } else if (lhs == nullptr) { - return IrBuilder::create(rhs); + return IrBuilder::IrBuilder::create(rhs); } else if (lhs->isConst()) { - return IrBuilder::create(lhs->value().value() + rhs); + return IrBuilder::IrBuilder::create(lhs->value().value() + rhs); } else if (rhs > 0) { - return IrBuilder::addExpr(lhs, IrBuilder::create(rhs)); + return IrBuilder::addExpr(lhs, IrBuilder::IrBuilder::create(rhs)); } else { - return IrBuilder::subExpr(lhs, IrBuilder::create(-rhs)); + return IrBuilder::subExpr(lhs, IrBuilder::IrBuilder::create(-rhs)); } } @@ -228,6 +255,15 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } } +Val* SimplifyingIrBuilder::addExpr(Val* lhs, Int::ScalarType rhs) { + auto lhs_int = dynamic_cast(lhs); + if (lhs_int != nullptr) { + return addExpr(lhs_int, rhs); + } else { + return addExpr(lhs, IrBuilder::create(rhs)); + } +} + Val* SimplifyingIrBuilder::subExpr(Val* lhs, Val* rhs) { return addExpr(lhs, negExpr(rhs)); } @@ -257,9 +293,9 @@ Val* SimplifyingIrBuilder::andExpr(Val* lhs, Val* rhs) { } if (lhs_definitely_true && rhs_definitely_true) { - return trueVal(); + return FusionGuard::getCurFusion()->trueVal(); } else if (lhs_definitely_false || rhs_definitely_false) { - return falseVal(); + return FusionGuard::getCurFusion()->falseVal(); } else if (lhs_definitely_true) { return rhs; } else if (rhs_definitely_true) { @@ -269,7 +305,65 @@ Val* SimplifyingIrBuilder::andExpr(Val* lhs, Val* rhs) { return IrBuilder::andExpr(lhs, rhs); } -} // namespace kir +namespace { + +template +Val* minOrMaxExpr( + Int* lhs, + Int* rhs, + IrBuilderFunc ir_builder_func, + IntFunc int_func) { + if (rhs == nullptr) { + return lhs; + } else if (lhs == nullptr) { + return rhs; + } else if (lhs->isConst() && rhs->isConst()) { + return IrBuilder::create( + int_func(lhs->value().value(), rhs->value().value())); + } else { + return ir_builder_func(lhs, rhs); + } +} + +template +Val* minOrMaxExpr( + Val* lhs, + Val* rhs, + IrBuilderFunc ir_builder_func, + IntFunc int_func) { + TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr); + if (lhs == nullptr) { + return rhs; + } else if (rhs == nullptr || lhs == rhs) { + return lhs; + } + auto lhs_int = dynamic_cast(lhs); + auto rhs_int = dynamic_cast(rhs); + if (lhs_int != nullptr && rhs_int != nullptr) { + return minOrMaxExpr(lhs_int, rhs_int, ir_builder_func, int_func); + } else { + return ir_builder_func(lhs, rhs); + } +} + +} // namespace + +Val* SimplifyingIrBuilder::maxExpr(Val* lhs, Val* rhs) { + return minOrMaxExpr( + lhs, + rhs, + [](Val* lhs, Val* rhs) { return IrBuilder::maxExpr(lhs, rhs); }, + [](int64_t lhs, int64_t rhs) { return std::max(lhs, rhs); }); +} + +Val* SimplifyingIrBuilder::minExpr(Val* lhs, Val* rhs) { + return minOrMaxExpr( + lhs, + rhs, + [](Val* lhs, Val* rhs) { return IrBuilder::minExpr(lhs, rhs); }, + [](int64_t lhs, int64_t rhs) { return std::min(lhs, rhs); }); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.h b/torch/csrc/jit/codegen/cuda/ir_builder.h new file mode 100644 index 00000000000..5087f2832a9 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_builder.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace kir { +class Kernel; +} + +class IrCloner; + +// Passkey for builder to register properties with statements, and to call +// functions in IrContainer +class TORCH_CUDA_CU_API IrBuilderPasskey { + friend class IrBuilder; + + public: + // TODO: Collapse ir_container and Kernel once Kernel inherits from + // IrContainer + IrContainer* const ir_container_ = nullptr; + + private: + explicit IrBuilderPasskey(IrContainer* ir_container); +}; + +//! IR builder interface +class TORCH_CUDA_CU_API IrBuilder { + public: + //! Allocate a new IR node, forwarding the arguments to the appropriate + //! constructor and registering with the container + template + static T* create(Args&&... args) { + auto container = FusionGuard::getCurFusion(); + // return create(container, std::forward(args)...); + TORCH_INTERNAL_ASSERT( + container != nullptr, "Need an active container to build IR."); + T* node = new T(IrBuilderPasskey(container), std::forward(args)...); + + container->registerStmt(IrBuilderPasskey(container), node); + + return node; + } + + //! Allocate a new IR node, forwarding the arguments to the appropriate + //! constructor and registering with the container + template + static T* create(IrContainer* container, Args&&... args) { + TORCH_INTERNAL_ASSERT( + container != nullptr, "Need an active container to build IR."); + T* node = new T(IrBuilderPasskey(container), std::forward(args)...); + + container->registerStmt(IrBuilderPasskey(container), node); + + return node; + } + + //! Clone an IR node, forwarding the arguments to the IrCloner constructor. + //! Register clones with IrCloner's target container. + template + static T* clone(const T* src, IrCloner* ir_cloner); + + // Unary operations + static Val* negExpr(Val* val); + static Val* notExpr(Val* val); + static Val* setExpr(Val* val); + static Val* setExprNamedScalar(const std::string& name, Val* val); + static Val* addressExprNamedScalar(const std::string& name, Val* val); + + // Binary operations + static Val* andExpr(Val* lhs, Val* rhs); + static Val* eqExpr(Val* lhs, Val* rhs); + static Val* gtExpr(Val* lhs, Val* rhs); + static Val* ltExpr(Val* lhs, Val* rhs); + static Val* leExpr(Val* lhs, Val* rhs); + static Val* geExpr(Val* lhs, Val* rhs); + static Val* addExpr(Val* lhs, Val* rhs); + static Val* subExpr(Val* lhs, Val* rhs); + static Val* mulExpr(Val* lhs, Val* rhs); + static Val* divExpr(Val* lhs, Val* rhs); + static Val* ceilDivExpr(Val* lhs, Val* rhs); + static Val* modExpr(Val* lhs, Val* rhs); + static Val* maxExpr(Val* lhs, Val* rhs); + static Val* minExpr(Val* lhs, Val* rhs); + + // Ternary operations + static Val* whereExpr(Val* pred, Val* lhs, Val* rhs); + + private: + static Val* newResult(DataType dtype); + static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs); + static Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs); +}; + +//! A wrapper builder with static expression simplification +//! +//! Example: +//! - addExpr(new Int(1), new Int(2)) -> Int(3) +//! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo") +//! +//! Designed to be used to simplify predicate and index expressions in +//! generated code. Also, the shift validation may fail without +//! this simplification. +class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { + public: + static Val* negExpr(Val* val); + static Val* notExpr(Val* val); + + static Val* addExpr(Int* lhs, Int::ScalarType rhs); + static Val* addExpr(Val* lhs, Int::ScalarType rhs); + static Val* addExpr(Int* lhs, Int* rhs); + static Val* addExpr(Val* lhs, Val* rhs); + static Val* subExpr(Val* lhs, Val* rhs); + static Val* andExpr(Val* lhs, Val* rhs); + static Val* maxExpr(Val* lhs, Val* rhs); + static Val* minExpr(Val* lhs, Val* rhs); +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 7e5a9cfa8bc..8a1717e8d05 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -2,12 +2,15 @@ #include #include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { +IrCloner::IrCloner(IrContainer* container) : ir_container_(container) {} + Statement* IrCloner::clone(const Statement* statement) { if (statement == nullptr) { return nullptr; @@ -30,7 +33,6 @@ Statement* IrCloner::clone(const Statement* statement) { // that something went horribly wrong. TORCH_INTERNAL_ASSERT(new_node != nullptr); TORCH_INTERNAL_ASSERT(clones_map_[statement] == new_node); - TORCH_INTERNAL_ASSERT(new_node->fusion() == fusion_); return new_node; } @@ -39,7 +41,6 @@ Statement* IrCloner::clone(const Statement* statement) { void IrCloner::registerClone(const Statement* src, Statement* clone) { TORCH_CHECK(src != nullptr); TORCH_CHECK(clone != nullptr); - TORCH_CHECK(clone->fusion() == fusion_); TORCH_CHECK(clones_map_.insert({src, clone}).second); } @@ -56,79 +57,79 @@ void IrCloner::handle(const Expr* e) { } void IrCloner::handle(const TensorDomain* td) { - clone_ = new TensorDomain(td, this); + clone_ = IrBuilder::clone(td, this); } void IrCloner::handle(const IterDomain* id) { - clone_ = new IterDomain(id, this); + clone_ = IrBuilder::clone(id, this); } void IrCloner::handle(const Bool* b) { - clone_ = new Bool(b, this); + clone_ = IrBuilder::clone(b, this); } void IrCloner::handle(const Double* d) { - clone_ = new Double(d, this); + clone_ = IrBuilder::clone(d, this); } void IrCloner::handle(const Int* i) { - clone_ = new Int(i, this); + clone_ = IrBuilder::clone(i, this); } void IrCloner::handle(const NamedScalar* named_scalar) { - clone_ = new NamedScalar(named_scalar, this); + clone_ = IrBuilder::clone(named_scalar, this); } void IrCloner::handle(const TensorView* tv) { - clone_ = new TensorView(tv, this); + clone_ = IrBuilder::clone(tv, this); } void IrCloner::handle(const UnaryOp* op) { - clone_ = new UnaryOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const BinaryOp* op) { - clone_ = new BinaryOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const TernaryOp* op) { - clone_ = new TernaryOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const BroadcastOp* op) { - clone_ = new BroadcastOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const ReductionOp* op) { - clone_ = new ReductionOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const WelfordOp* op) { - clone_ = new WelfordOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const TransposeOp* op) { - clone_ = new TransposeOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const ShiftOp* op) { - clone_ = new ShiftOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const GatherOp* op) { - clone_ = new GatherOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const ViewOp* op) { - clone_ = new ViewOp(op, this); + clone_ = IrBuilder::clone(op, this); } void IrCloner::handle(const Split* split) { - clone_ = new Split(split, this); + clone_ = IrBuilder::clone(split, this); } void IrCloner::handle(const Merge* merge) { - clone_ = new Merge(merge, this); + clone_ = IrBuilder::clone(merge, this); } TensorView* RecomputeTv::recompute(TensorView* tv) { @@ -141,7 +142,7 @@ TensorView* RecomputeTv::recompute(TensorView* tv) { "Cannot recompute buffers that are inputs of the fusion."); // Grab all the expressions used to generate the TensorView - auto exprs = ExprSort::getExprs(tv->fusion(), {tv}); + auto exprs = StmtSort::getExprs(tv->fusion(), {tv}, false); // Run the replicator RecomputeTv replicator(tv->fusion(), exprs); @@ -161,7 +162,7 @@ TensorView* RecomputeTv::recompute(TensorView* tv) { } RecomputeTv::RecomputeTv(Fusion* fusion, std::vector exprs) - : IrCloner(fusion) { + : IrCloner(fusion), fusion_(fusion) { // Add inputs to the clones map to prevent cloning them. for (const auto inp : fusion->inputs()) { clones_map_[inp] = inp; @@ -183,7 +184,7 @@ void RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. auto exprs = - ExprSort::getExprs(fusion(), {td->domain().begin(), td->domain().end()}); + StmtSort::getExprs(fusion_, {td->domain().begin(), td->domain().end()}); for (auto expr : exprs) { IrCloner::handle(expr); diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index ac83d9edb09..1755b9e9563 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include #include #include @@ -11,7 +12,7 @@ namespace jit { namespace fuser { namespace cuda { -class Fusion; +class IrContainer; //! Clones nodes from an exiting Fusion //! @@ -21,10 +22,11 @@ class Fusion; //! class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { friend class Statement; + friend class IrBuilder; public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit IrCloner(Fusion* new_fusion) : fusion_(new_fusion) {} + explicit IrCloner(IrContainer* container); Statement* clone(const Statement* statement); @@ -45,8 +47,8 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { return copy; } - Fusion* fusion() const { - return fusion_; + IrContainer* container() const { + return ir_container_; } protected: @@ -86,12 +88,15 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { private: // The destination Fusion container - Fusion* fusion_ = nullptr; + IrContainer* ir_container_ = nullptr; // The dispatch interface doesn't allow returning values from // individual `handle()` methods, so they are storing the // result here Statement* clone_ = nullptr; + + // Builder to make all the new nodes + IrBuilder builder_; }; // Replicates all expressions used to generate the provided TensorView. Does not @@ -105,7 +110,9 @@ class RecomputeTv : private IrCloner { private: RecomputeTv(Fusion* fusion, std::vector exprs); - void handle(const TensorDomain*) override; + void handle(const TensorDomain*) final; + + Fusion* fusion_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/ir_container.cpp b/torch/csrc/jit/codegen/cuda/ir_container.cpp new file mode 100644 index 00000000000..e84418eb973 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_container.cpp @@ -0,0 +1,279 @@ +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +void swap(IrContainer& a, IrContainer& b) noexcept { + FUSER_PERF_SCOPE("Fusion swap"); + + using std::swap; + + // Swap the content + swap(a.vals_up_, b.vals_up_); + swap(a.vals_, b.vals_); + + swap(a.exprs_up_, b.exprs_up_); + swap(a.exprs_, b.exprs_); + + swap(a.raw_ptrs_, b.raw_ptrs_); + + swap(a.val_type_name_map_, b.val_type_name_map_); + swap(a.expr_name_counter_, b.expr_name_counter_); + + // Fixup the Statement::fusion_ links for a + for (auto val : a.vals_) { + val->ir_container_ = &a; + } + for (auto expr : a.exprs_) { + expr->ir_container_ = &a; + } + + // Fixup the Statement::fusion_ links for b + for (auto val : b.vals_) { + val->ir_container_ = &a; + } + for (auto expr : b.exprs_) { + expr->ir_container_ = &a; + } +} + +IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { + to->clear(); + IrCloner ir_cloner(to); + + for (auto val : from->vals_) { + to->vals_.insert(ir_cloner.clone(val)); + } + + for (auto expr : from->exprs_) { + to->exprs_.insert(ir_cloner.clone(expr)); + } + + to->val_type_name_map_ = from->val_type_name_map_; + to->expr_name_counter_ = from->expr_name_counter_; + + return ir_cloner; +} + +IrContainer::IrContainer() = default; + +IrContainer::IrContainer(const IrContainer& other) { + FUSER_PERF_SCOPE("IrContainer copy"); + IrContainer::copy(&other, this); +} + +IrContainer::IrContainer(IrContainer&& other) noexcept { + FUSER_PERF_SCOPE("IrContainer move"); + swap(*this, other); +} + +IrContainer& IrContainer::operator=(const IrContainer& other) { + FUSER_PERF_SCOPE("IrContainer copy assign"); + IrContainer copy(other); + clear(); + swap(*this, copy); + return *this; +} + +IrContainer& IrContainer::operator=(IrContainer&& other) noexcept { + FUSER_PERF_SCOPE("IrContainer move assign"); + clear(); + swap(*this, other); + return *this; +} + +IrContainer::~IrContainer() { + clear(); +} + +//! Register the Statement with this container +void IrContainer::registerStmt(IrBuilderPasskey, Statement* stmt) { + if (stmt->isVal()) { + registerVal(stmt->asVal()); + } else { + registerExpr(stmt->asExpr()); + } +} + +//! Register the Val with this container +void IrContainer::registerVal(IrBuilderPasskey, Val* val) { + registerVal(val); +} + +//! Register expr with this container. +void IrContainer::registerExpr(IrBuilderPasskey, Expr* expr) { + registerExpr(expr); +} + +void IrContainer::registerExpr(ExprPasskey, Expr* expr) { + registerExpr(expr); +} + +void IrContainer::removeExpr(Expr* expr) { + TORCH_INTERNAL_ASSERT( + exprs_.find(expr) != exprs_.end(), + "Wanted to remove an expression but it doesn't exist in this container."); + auto expr_in_deque = std::find_if( + exprs_up_.begin(), + exprs_up_.end(), + [expr](std::unique_ptr& expr_up) { return expr_up.get() == expr; }); + + TORCH_INTERNAL_ASSERT( + expr_in_deque != exprs_up_.end(), + "Wanted to remove an expression but its unique ptr is missing."); + + exprs_.erase(expr); + exprs_up_.erase(expr_in_deque); + raw_ptrs_.erase((void*)expr); +} + +//! Completely remove val from the fusion, break all dependencies associated +//! with it +void IrContainer::removeVal(Val* val) { + // Don't remove shortcuts + if (val == true_val_.get() || val == false_val_.get() || + val == one_val_.get() || val == zero_val_.get() || + val == magic_zero_val_.get()) { + return; + } + + TORCH_INTERNAL_ASSERT( + vals_.find(val) != vals_.end(), + "Wanted to remove a value but it doesn't exist in this container."); + auto val_in_deque = std::find_if( + vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr& val_up) { + return val_up.get() == val; + }); + + TORCH_INTERNAL_ASSERT( + val_in_deque != vals_up_.end(), + "Wanted to remove a value but its unique ptr is missing."); + + vals_.erase(val); + vals_up_.erase(val_in_deque); + raw_ptrs_.erase((void*)val); +} + +//! Register the Val with this container +void IrContainer::registerVal(Val* val) { + if (inContainer(val)) { + return; + } + + vals_up_.emplace_back(std::unique_ptr(val)); + vals_.emplace(vals_up_.back().get()); + val->setName(IrContainerPasskey(), getValName(vals_up_.back()->vtype())); + raw_ptrs_.emplace((void*)vals_up_.back().get()); +} + +//! Register expr with this container. +void IrContainer::registerExpr(Expr* expr) { + if (inContainer(expr)) { + return; + } + exprs_up_.emplace_back(std::unique_ptr(expr)); + exprs_.emplace(exprs_up_.back().get()); + expr->setName(IrContainerPasskey(), getExprName()); + raw_ptrs_.emplace((void*)exprs_up_.back().get()); +} + +void IrContainer::clear() noexcept { + FUSER_PERF_SCOPE("IrContainer clear"); + vals_.clear(); + vals_up_.clear(); + exprs_.clear(); + exprs_up_.clear(); + raw_ptrs_.clear(); + + val_type_name_map_.clear(); + expr_name_counter_ = 0; +} + +bool IrContainer::inContainer(const Statement* stmt) const { + const void* const_void = (const void*)(stmt); + void* nonconst_void = const_cast(const_void); // NOLINT + if (raw_ptrs_.find(nonconst_void) == raw_ptrs_.end()) { + return false; + } + + TORCH_INTERNAL_ASSERT( + stmt->container() == this, + "Container claims to own stmt, but stmt disagrees."); + + Statement* nonconst_stmt = const_cast(stmt); // NOLINT + if (stmt->isExpr()) { + TORCH_INTERNAL_ASSERT( + exprs_.find(nonconst_stmt->as()) != exprs_.end(), + "Somehow container claims to and not to own an Expr."); + } + if (stmt->isVal()) { + TORCH_INTERNAL_ASSERT( + vals_.find(nonconst_stmt->as()) != vals_.end(), + "Somehow container claims to and not to own an Val."); + } + + return true; +} + +// Shortcuts for frequently used vals +Int* IrContainer::zeroVal() { + if (!zero_val_) { + auto zero_val = IrBuilder::create(this, 0); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == zero_val); + zero_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return zero_val_.get(); +} + +Int* IrContainer::oneVal() { + if (!one_val_) { + auto one_val = IrBuilder::create(this, 1); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == one_val); + one_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return one_val_.get(); +} + +Bool* IrContainer::falseVal() { + if (!false_val_) { + auto false_val = IrBuilder::create(this, false); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == false_val); + false_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return false_val_.get(); +} + +Bool* IrContainer::trueVal() { + if (!true_val_) { + auto true_val = IrBuilder::create(this, true); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == true_val); + true_val_ = std::unique_ptr(vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return true_val_.get(); +} + +NamedScalar* IrContainer::magicZeroVal() { + if (!magic_zero_val_) { + auto magic_zero = + IrBuilder::create(kMagicZeroName, DataType::Int); + TORCH_INTERNAL_ASSERT(vals_up_.back().get() == magic_zero); + magic_zero_val_ = std::unique_ptr( + vals_up_.back().release()->as()); + vals_up_.pop_back(); + } + return magic_zero_val_.get(); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_container.h b/torch/csrc/jit/codegen/cuda/ir_container.h new file mode 100644 index 00000000000..fb1aaeaf383 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ir_container.h @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class IrBuilderPasskey; +class ExprPasskey; +class OptOutMutator; + +class Int; +class Bool; +class NamedScalar; + +// Passkey for container to register names with statements +class IrContainerPasskey { + friend class IrContainer; + + private: + explicit IrContainerPasskey() {} +}; + +class TORCH_CUDA_CU_API IrContainer : public PolymorphicBase { + public: + IrContainer(); + + IrContainer(const IrContainer& other); + IrContainer(IrContainer&& other) noexcept; + + IrContainer& operator=(const IrContainer& other); + IrContainer& operator=(IrContainer&& other) noexcept; + + virtual ~IrContainer(); + + bool inContainer(const Statement* stmt) const; + + void assertInContainer(const Statement* stmt, const std::string& msg) const { + TORCH_CHECK( + inContainer(stmt), msg, " it was not found in the active container."); + } + + //! Return in insertion order + const std::deque deterministic_vals() const noexcept { + std::deque vals_deque; + std::transform( + vals_up_.begin(), + vals_up_.end(), + std::back_inserter(vals_deque), + [](const std::unique_ptr& val_up) { return val_up.get(); }); + return vals_deque; + } + + //! Register the Statement with this container + virtual void registerStmt(IrBuilderPasskey, Statement* stmt); + + //! Register the Val with this container + virtual void registerVal(IrBuilderPasskey, Val* val); + + //! Register expr with this container. + virtual void registerExpr(IrBuilderPasskey, Expr* expr); + + //! Allow expr's to register themselves with a container, this is only used + //! for broadcastOp so it can register itself in its constructor so root maps + //! can be built. + virtual void registerExpr(ExprPasskey, Expr* expr); + + //! Return the set of Exprs registered with this fusion. Warning: This will + //! return exprs outside inputs/outputs, so can be unsafe for use with + //! segmented fusions. + const std::unordered_set& unordered_exprs() const noexcept { + return exprs_; + } + + //! Return the set of Vals registered with this fusion + const std::unordered_set& vals() const noexcept { + return vals_; + } + + // Shortcuts for frequently used vals + Int* zeroVal(); + Int* oneVal(); + Bool* falseVal(); + Bool* trueVal(); + NamedScalar* magicZeroVal(); + + protected: + static IrCloner copy(const IrContainer* from, IrContainer* to); + + friend void swap(IrContainer& a, IrContainer& b) noexcept; + + // Let mutator remove Exprs. + friend OptOutMutator; + + virtual void removeExpr(Expr* expr); + + //! Completely remove val from the fusion, break all dependencies associated + //! with it + virtual void removeVal(Val* val); + + //! Register the Val with this container + virtual void registerVal(Val* val); + + //! Register expr with this container. + virtual void registerExpr(Expr* expr); + + StmtNameType getValName(ValType vtype) { + if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { + val_type_name_map_[vtype] = 0; + } + return val_type_name_map_[vtype]++; + } + + StmtNameType getExprName() { + return expr_name_counter_++; + } + + void clear() noexcept; + + // Deque of unique pointer is the memory owning data structure + std::deque> vals_up_; + + // A convenient set to return when we just need an unordered set to do + // something like check if a Val is in this container + std::unordered_set vals_; + + // Deque of unique pointer is the memory owning data structure + std::deque> exprs_up_; + + // A convenient set to return when we just need an unordered set to do + // something like check if an Expr is in this container + std::unordered_set exprs_; + + // Used to implement a generic "inContainer" that can be passed an invalid + // pointer. Specifically a pointer to a Statement owned by another container + // that has been freed. We can't check normally with the unordered_sets we + // already have because it would require a const_cast from a constant + // expr/val, or a dynamic cast from a Statement. + std::unordered_set raw_ptrs_; + + // Values names counters + std::unordered_map val_type_name_map_; + + // Expression names counter + StmtNameType expr_name_counter_ = 0; + + // Manually store some persistent, frequently used nodes. It's very + // challenging to do this anything but manually as detecting when a container + // may or may not have one of these vals is tricky. Specifically because if + // the container doesn't own it, it's hard to understand from the outside if + // the node may have been removed then re-registered. It could also be tricky + // to know when we're using a different container as in FusionCopy_test + // demonstrates deleting then creating containers can result in the same + // pointer for the container. + std::unique_ptr true_val_; + std::unique_ptr false_val_; + std::unique_ptr one_val_; + std::unique_ptr zero_val_; + std::unique_ptr magic_zero_val_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp index 5ca8d54aaa9..7511fbd4d6d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -303,13 +304,13 @@ void IrGraphGenerator::generateScheduleGraph() { // Maybe not the best way to handle the root domain, but should be okay addArc( tv, - new TensorDomain(tv->getRootDomain()), + IrBuilder::create(tv->getRootDomain()), "[style=dashed, color=green, arrowhead=none]"); if (tv->domain()->hasRFactor()) addArc( tv, - new TensorDomain(tv->domain()->getRFactorDomain()), + IrBuilder::create(tv->domain()->getRFactorDomain()), "[style=dashed, color=green, arrowhead=none]"); } } diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h index 1144d95eb15..f9b3adf703d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 02c319d3665..28478c64d91 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -19,6 +19,9 @@ namespace cuda { class WelfordResult; class ViewTransform; +class IrCloner; +class IrBuilderPasskey; + //! A Bool value //! //! This value can be a symbolic value (defined after the kernel @@ -26,17 +29,18 @@ class ViewTransform; //! class TORCH_CUDA_CU_API Bool : public Val { public: - Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {} + Bool(IrBuilderPasskey passkey); + + explicit Bool(IrBuilderPasskey passkey, bool value); - explicit Bool(bool value) - : Val(ValType::Scalar, DataType::Bool), maybe_value_{value} {} + explicit Bool(IrBuilderPasskey passkey, c10::optional value); Bool(const Bool* src, IrCloner* ir_cloner); bool isSymbolic() const { return !(maybe_value_.has_value()); } - bool isConst() const { + bool isConst() const final { return maybe_value_.has_value(); } c10::optional value() const { @@ -56,18 +60,18 @@ class TORCH_CUDA_CU_API Double : public Val { public: using ScalarType = double; - Double() - : Val(ValType::Scalar, DataType::Double), maybe_value_{c10::nullopt} {} + Double(IrBuilderPasskey passkey); + + explicit Double(IrBuilderPasskey passkey, ScalarType value); - explicit Double(ScalarType value) - : Val(ValType::Scalar, DataType::Double), maybe_value_{value} {} + explicit Double(IrBuilderPasskey passkey, c10::optional value); Double(const Double* src, IrCloner* ir_cloner); bool isSymbolic() const { return !(maybe_value_.has_value()); } - bool isConst() const { + bool isConst() const final { return maybe_value_.has_value(); } c10::optional value() const { @@ -86,17 +90,18 @@ class TORCH_CUDA_CU_API Int : public Val { public: using ScalarType = int64_t; - Int() : Val(ValType::Scalar, DataType::Int), maybe_value_{c10::nullopt} {} + Int(IrBuilderPasskey passkey); - explicit Int(ScalarType value) - : Val(ValType::Scalar, DataType::Int), maybe_value_{value} {} + explicit Int(IrBuilderPasskey passkey, ScalarType value); + + explicit Int(IrBuilderPasskey passkey, c10::optional value); Int(const Int* src, IrCloner* ir_cloner); bool isSymbolic() const { return !(maybe_value_.has_value()); } - bool isConst() const { + bool isConst() const final { return maybe_value_.has_value(); } c10::optional value() const { @@ -152,14 +157,18 @@ class TVDomainGuard; class TORCH_CUDA_CU_API TensorView : public Val { public: TensorView( + IrBuilderPasskey passkey, TensorDomain* domain, DataType dtype, MemoryType mtype = MemoryType::Local); - explicit TensorView(const std::shared_ptr& tensor_type); + explicit TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& tensor_type); - explicit TensorView(const std::shared_ptr& jit_value) - : TensorView(jit_value->type()->cast()) {} + explicit TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& jit_value); TensorView(const TensorView* src, IrCloner* ir_cloner); @@ -187,6 +196,16 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! trivial reductions bool hasAnyReduction() const; + //! Returns true if this tensor is zero dimensional, + //! i.e. a wrapped scalar or an empty placeholder. + bool isZeroDim() const { + return nDims() == 0; + } + + //! Returns true if this tensor does not contain + //! any value. + bool isEmptyTensor() const; + c10::optional getReductionAxis() const; const std::vector& getRootDomain() const; @@ -210,6 +229,24 @@ class TORCH_CUDA_CU_API TensorView : public Val { size_t nDims() const; + // sets cpu_scalar_ value, which is special handling for CPU based zero-dim + // tensors (i.e. CPU Tensors that only have one value). This is only used if + // on an input value, otherwise ignored. This is important as special handling + // because these "scalars" should be type promoted as a tensor, but we want to + // avoid explicit copying of the data, so we want to pass the data value as a + // standard kernel argument value. + void setCpuScalar(bool is_cpu_scalar); + + // returns cpu_scalar_ value, which is special handling for CPU based zero-dim + // tensors (i.e. CPU Tensors that only have one value). This is only used if + // on an input value, otherwise ignored. This is important as special handling + // because these "scalars" should be type promoted as a tensor, but we want to + // avoid explicit copying of the data, so we want to pass the data value as a + // standard kernel argument value. + bool isCpuScalar() const { + return cpu_scalar_; + } + // Returns the position that this tensor is produced at relative to its axes. unsigned int getComputeAtPosition() const { return compute_at_pos_; @@ -356,6 +393,13 @@ class TORCH_CUDA_CU_API TensorView : public Val { return axes_to_swizzle_; } + // Apply double buffering transformation + void doubleBuffer(); + + bool isDoubleBuffered() const { + return is_double_buffered_; + } + friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; @@ -393,6 +437,14 @@ class TORCH_CUDA_CU_API TensorView : public Val { MemoryType memory_type_ = MemoryType::Local; SwizzleType swizzle_type_ = SwizzleType::NoSwizzle; std::vector axes_to_swizzle_; + bool is_double_buffered_ = false; + // special handling for CPU based zero-dim tensors (i.e. CPU Tensors that only + // have one value). This is only used if on an input value, otherwise ignored. + // This is important as special handling because these "scalars" should be + // type promoted as a tensor, but we want to avoid explicit copying of the + // data, so we want to pass the data value as a standard kernel argument + // value. + bool cpu_scalar_ = false; }; //! A simple TensorView builder diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 8fd4475d2dd..bb494148be2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -1,10 +1,11 @@ #pragma once -#include +#include #include #include #include +#include //! Nodes in here should generally not be used by users. They should be behind //! the scenes and users shouldn't have to be aware of what they do to use the @@ -20,6 +21,8 @@ namespace fuser { namespace cuda { class ViewTransform; +class Scope; +class IrCloner; //! Returns true if both v1 and v2 are scalars, are the same type of scalars, //! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both @@ -34,7 +37,7 @@ bool areEqualScalars(Val* v1, Val* v2); //! 4) split/merge class TORCH_CUDA_CU_API UnaryOp : public Expr { public: - UnaryOp(UnaryOpType type, Val* out, Val* in); + UnaryOp(IrBuilderPasskey, UnaryOpType type, Val* out, Val* in); UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); @@ -63,7 +66,7 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { //! 2) LT (A < B) class TORCH_CUDA_CU_API BinaryOp : public Expr { public: - BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs); + BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); @@ -97,7 +100,11 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { //! \param out The output tensor //! \param in The input tensor //! \param is_broadcast_dims True when output dim is a new broadcast domain - BroadcastOp(Val* out, Val* in, std::vector is_broadcast_dims); + BroadcastOp( + IrBuilderPasskey, + Val* out, + Val* in, + std::vector is_broadcast_dims); BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); @@ -138,7 +145,12 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { //! non-reduction/non-broadcast dimensions. class TORCH_CUDA_CU_API ReductionOp : public Expr { public: - ReductionOp(BinaryOpType reduction_op_type, Val* init, Val* out, Val* in); + ReductionOp( + IrBuilderPasskey, + BinaryOpType reduction_op_type, + Val* init, + Val* out, + Val* in); ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); @@ -169,6 +181,7 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { class TORCH_CUDA_CU_API WelfordOp : public Expr { public: WelfordOp( + IrBuilderPasskey, Val* out_avg, Val* out_var, Val* out_N, @@ -189,10 +202,6 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { return in_avg_; } - Val* init() const { - return init_avg_; - } - bool sameAs(const Statement* const other) const override; // Welford Accessors @@ -255,7 +264,11 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { class TORCH_CUDA_CU_API TransposeOp : public Expr { public: - TransposeOp(TensorView* out, TensorView* in, std::vector new2old); + TransposeOp( + IrBuilderPasskey, + TensorView* out, + TensorView* in, + std::vector new2old); TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); @@ -279,7 +292,13 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { class TORCH_CUDA_CU_API TernaryOp : public Expr { public: - TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3); + TernaryOp( + IrBuilderPasskey, + TernaryOpType type, + Val* out, + Val* in1, + Val* in2, + Val* in3); TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); @@ -317,7 +336,12 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { //! \param out //! \param in //! \param offsets - ShiftOp(Val* out, Val* in, std::vector offsets, bool pad); + ShiftOp( + IrBuilderPasskey, + Val* out, + Val* in, + std::vector offsets, + std::vector pad_width); ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); @@ -336,8 +360,14 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { return offsets_; } - bool pad() const { - return pad_; + const std::vector& padWidth() const { + return pad_width_; + } + + bool hasPadding() const { + return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto p) { + return p > 0; + }); } bool sameAs(const Statement* other) const override; @@ -349,17 +379,18 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { //! offsets_. The sign of each value indicates the direction of //! shifting. const std::vector offsets_; - const bool pad_; + const std::vector pad_width_; }; //! Gather a window around each element. class TORCH_CUDA_CU_API GatherOp : public Expr { public: GatherOp( + IrBuilderPasskey, Val* out, Val* in, - std::vector window_shape, - std::vector> pad_width); + std::vector window_shape, + std::vector> pad_width); GatherOp(const GatherOp* src, IrCloner* ir_cloner); @@ -381,20 +412,26 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { return pad_width_; } + bool hasPadding() const { + return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto& p) { + return p[0] > 0 || p[1] > 0; + }); + } + bool sameAs(const Statement* other) const override; private: Val* const out_ = nullptr; Val* const in_ = nullptr; //! Shape of a window gathered for each element. - std::vector window_shape_; + std::vector window_shape_; //! The size of zero-padding of each axis. - std::vector> pad_width_; + std::vector> pad_width_; }; class TORCH_CUDA_CU_API ViewOp : public Expr { public: - ViewOp(TensorView* out, TensorView* in); + ViewOp(IrBuilderPasskey, TensorView* out, TensorView* in); ViewOp(const ViewOp* src, IrCloner* ir_cloner); @@ -422,6 +459,7 @@ class IndexReferenceReplay; class TORCH_CUDA_CU_API IterDomain : public Val { public: IterDomain( + IrBuilderPasskey, Val* start, Val* extent, ParallelType parallel_type = ParallelType::Serial, @@ -429,6 +467,7 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool is_rfactor_domain = false); IterDomain( + IrBuilderPasskey, Val* start, Val* extent, Val* stop_offset, @@ -441,20 +480,7 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool sameAs(const Statement* other) const override; // Returns a new IterDomain matching properties of this - // TODO: parallel_method->getParallelType - IterDomain* clone() const { - auto cloned = new IterDomain( - start(), - extent(), - stopOffset(), - getParallelType(), - getIterType(), - isRFactorProduct()); - - cloned->is_padded_dimension_ = is_padded_dimension_; - cloned->padded_to_size_ = padded_to_size_; - return cloned; - } + IterDomain* clone() const; //! Clone a vector domains static std::vector clone( @@ -631,6 +657,11 @@ class TORCH_CUDA_CU_API IterDomain : public Val { //! domain. std::pair stridedSplit(int factor); + // TODO: Remove + bool isSimple() const { + return definition() == nullptr; + } + protected: friend TensorDomain; friend ReplayTransformations; @@ -647,6 +678,10 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool is_rfactor_domain_ = false; bool is_padded_dimension_ = false; c10::optional padded_to_size_ = c10::nullopt; + + // TODO: Remove only used in kernel IR because IterDomains don't maintain + // definitions of split/merge. + bool is_simple_ = true; }; //! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every @@ -666,15 +701,18 @@ class TORCH_CUDA_CU_API IterDomain : public Val { class TORCH_CUDA_CU_API TensorDomain : public Val { public: explicit TensorDomain( + IrBuilderPasskey, std::vector root_domain, std::vector contiguity = std::vector()); TensorDomain( + IrBuilderPasskey, std::vector root_domain, std::vector domain, std::vector contiguity = std::vector()); TensorDomain( + IrBuilderPasskey, std::vector root_domain, std::vector rfactor_domain, std::vector domain, @@ -718,6 +756,8 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { bool hasReduction() const; bool hasBlockReduction() const; bool hasGridReduction() const; + bool hasBlockBroadcast() const; + bool hasGridBroadcast() const; bool hasBroadcast() const; bool hasRFactor() const; bool hasVectorize() const; @@ -821,6 +861,7 @@ class TORCH_CUDA_CU_API Split : public Expr { // start_offset and stop_offset are distance from the left end and // right ends, respectively. Split( + IrBuilderPasskey, IterDomain* outer, IterDomain* inner, IterDomain* in, @@ -881,12 +922,13 @@ class TORCH_CUDA_CU_API Split : public Expr { //! dictate which will be traversed first (inner). Both IterDomains must be of //! the same iter or reduction type, as well as the same parallelization //! strategy if there is one -//! -//! \todo Should this be a unary op type? -//! class TORCH_CUDA_CU_API Merge : public Expr { public: - Merge(IterDomain* out, IterDomain* outer, IterDomain* inner); + Merge( + IrBuilderPasskey, + IterDomain* out, + IterDomain* outer, + IterDomain* inner); Merge(const Merge* src, IrCloner* ir_cloner); @@ -918,9 +960,7 @@ class TORCH_CUDA_CU_API Merge : public Expr { //! class TORCH_CUDA_CU_API NamedScalar : public Val { public: - // NOLINTNEXTLINE(modernize-pass-by-value) - NamedScalar(std::string name, DataType dtype) - : Val(ValType::NamedScalar, dtype), name_(name) {} + NamedScalar(IrBuilderPasskey passkey, std::string name, DataType dtype); NamedScalar(const NamedScalar* src, IrCloner* ir_cloner); @@ -931,9 +971,11 @@ class TORCH_CUDA_CU_API NamedScalar : public Val { bool sameAs(const Statement* other) const override; //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x) + //! WARNING: Only works with Fusion container at the moment static NamedScalar* getParallelDim(ParallelType p_type); //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x) + //! WARNING: Only works with Fusion container at the moment static NamedScalar* getParallelIndex(ParallelType p_type); //! Return the parallel type of this NamedScalar if it is an extent of a diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index a553c59fc2b..8c0e1022308 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -14,6 +15,23 @@ namespace jit { namespace fuser { namespace cuda { +namespace { +const char* boolLiteral(bool value) { + return value ? "true" : "false"; +} + +std::string varName(const Val* val) { + std::stringstream value_name; + if (val == nullptr) { + value_name << "$nullptr"; + } else { + value_name << val->name(); + } + return value_name.str(); +} + +} // namespace + // Make sure we can inline something, before we attempt to. static void checkInlineable(const Expr* expr) { for (auto input : expr->inputs()) { @@ -49,6 +67,70 @@ void IrPrinter::handle(Fusion* fusion) { } } +void IrPrinter::handle(const kir::Kernel* kernel) { + TORCH_CHECK(kernel != nullptr); + + // kernel declaration + os_ << "\nKERNEL ("; + for (auto in : kernel->inputs()) { + handle(in); + if (in != kernel->inputs().back()) { + os_ << ", "; + } + } + os_ << ") -> ("; + for (auto out : kernel->outputs()) { + handle(out); + if (out != kernel->outputs().back()) { + os_ << ", "; + } + } + os_ << ") :\n"; + + // kernel body + indent_size_++; + for (auto expr : kernel->topLevelExprs()) { + handle(expr); + } + indent_size_--; + os_ << "END.\n\n"; +} + +void IrPrinter::handle(kir::Kernel& kernel) { + handle(&kernel); +} + +void IrPrinter::handleScope(const kir::Scope& scope) { + // Save the uses of the parent scope + indent_size_++; + for (auto expr : scope.exprs()) { + handle(expr); + } + indent_size_--; +} + +void IrPrinter::handle(const IterDomain* id) { + os_ << id->getIterType(); + os_ << id->getParallelType(); + os_ << varName(id); + os_ << "{"; + if (!id->start()->isZeroInt()) { + print_inline(id->start()); + os_ << " : "; + } + if (id->stop() != id->extent()) { + print_inline(id->stop()); + os_ << " : "; + } + print_inline(id->extent()); + os_ << "}"; + if (id->isRFactorProduct()) + os_ << "rf"; + if (id->hasPaddingToMultipleOfWarp()) { + os_ << "_p"; + } +} + void IrPrinter::handle(const TensorDomain* td) { if (td->nDims() == 0) { os_ << "[ 0 ]"; @@ -65,9 +147,9 @@ void IrPrinter::handle(const TensorDomain* td) { void IrPrinter::handle(const TensorView* tv) { if (tv->nDims() == 0) { - os_ << typePrefix(tv->getDataType().value()) << tv->name(); + os_ << typePrefix(tv->getDataType().value()) << varName(tv); } else { - os_ << "T" << tv->name(); + os_ << "T" << varName(tv); switch (tv->getMemoryType()) { case MemoryType::Global: os_ << "_g"; @@ -94,28 +176,6 @@ void IrPrinter::handle(const TensorView* tv) { } } -void IrPrinter::handle(const IterDomain* id) { - os_ << id->getIterType(); - os_ << id->getParallelType(); - os_ << id->name(); - os_ << "{"; - if (!id->start()->isZeroInt()) { - print_inline(id->start()); - os_ << " : "; - } - if (id->stop() != id->extent()) { - print_inline(id->stop()); - os_ << " : "; - } - print_inline(id->extent()); - os_ << "}"; - if (id->isRFactorProduct()) - os_ << "rf"; - if (id->hasPaddingToMultipleOfWarp()) { - os_ << "_p"; - } -} - void IrPrinter::handle(const Bool* b) { if (print_inline_ && b->definition() != nullptr) { os_ << "( "; @@ -124,10 +184,9 @@ void IrPrinter::handle(const Bool* b) { return; } - if (b->isSymbolic()) { - os_ << "b" << b->name(); - } else { - os_ << "bool(" << *(b->value()) << ")"; + os_ << "b" << varName(b); + if (b->isConst()) { + os_ << "(" << (b->value().value() ? "true" : "false") << ")"; } } @@ -140,7 +199,7 @@ void IrPrinter::handle(const Double* d) { } if (d->isSymbolic()) { - os_ << "d" << d->name(); + os_ << "d" << varName(d); } else { os_ << "double(" << std::setprecision( @@ -160,30 +219,20 @@ void IrPrinter::handle(const Int* i) { } if (i->isSymbolic()) { - os_ << "i" << i->name(); + os_ << "i" << varName(i); } else { os_ << *(i->value()); } } -void IrPrinter::handle(const NamedScalar* i) { - os_ << i->name(); -} - -static bool isTV(const Val* val) { - return val->getValType().value() == ValType::TensorView; -} - -// Check if we're a TensorView op that we can generate code for. -static bool isTVOp(const Expr* expr) { - return expr->outputs().size() == 1 && isTV(expr->outputs().front()); +void IrPrinter::handle(const NamedScalar* ns) { + os_ << ns->name(); } void IrPrinter::handle(const UnaryOp* uop) { - bool istvop = isTVOp(uop); + bool istvop = ir_utils::isTvOp(uop); if (!print_inline_) { - indent(); - os_ << uop->out(); + indent() << uop->out(); if (istvop) { os_ << "\n"; indent_size_++; @@ -230,10 +279,9 @@ void IrPrinter::handle(const UnaryOp* uop) { } void IrPrinter::handle(const BinaryOp* bop) { - bool istvop = isTVOp(bop); + bool istvop = ir_utils::isTvOp(bop); if (!print_inline_) { - indent(); - os_ << bop->out(); + indent() << bop->out(); // tensor operations tend to be long, break them up into multiple lines if (istvop) { @@ -286,7 +334,7 @@ void IrPrinter::handle(const BinaryOp* bop) { } void IrPrinter::handle(const TernaryOp* top) { - bool istvop = isTVOp(top); + bool istvop = ir_utils::isTvOp(top); if (!print_inline_) { indent(); os_ << top->out(); @@ -327,18 +375,16 @@ void IrPrinter::handle(const TernaryOp* top) { } void IrPrinter::handle(const ReductionOp* rop) { - indent(); - os_ << rop->out() << " = reduction( " << rop->in() - << ", op = " << rop->getReductionOpType() - << ", initial value = " << rop->init() << " )\n"; + indent() << rop->out() << " = reduction( " << rop->in() + << ", op = " << rop->getReductionOpType() + << ", initial value = " << rop->init() << " )\n"; } void IrPrinter::handle(const WelfordOp* wop) { - indent(); - os_ << wop->outAvg() << "(Avg),\n" - << wop->outVar() << "(Var),\n" - << wop->outN() << "(Count)" - << "\n = Welford ( "; + indent() << wop->outAvg() << "(Avg),\n" + << wop->outVar() << "(Var),\n" + << wop->outN() << "(Count)" + << "\n = Welford ( "; if (wop->singleValue()) { os_ << wop->inAvg() << "(Avg), "; } else { @@ -353,24 +399,48 @@ void IrPrinter::handle(const WelfordOp* wop) { } void IrPrinter::handle(const BroadcastOp* bop) { - indent(); - os_ << bop->out() << " = broadcast( " << bop->in() << " )\n"; + indent() << bop->out() << " = broadcast( " << bop->in() << " )\n"; +} + +void IrPrinter::handle(const Split* s) { + os_ << (s->innerSplit() ? "Split: " : "Outer split: "); + handle(s->in()); + os_ << " by factor " << s->factor() << " -> "; + handle(s->outer()); + os_ << ", "; + handle(s->inner()); + if (s->startOffset()) { + os_ << ", start offset: "; + handle(s->startOffset()); + } + if (s->stopOffset()) { + os_ << ", stop offset: "; + handle(s->stopOffset()); + } + os_ << "\n"; +} + +void IrPrinter::handle(const Merge* m) { + os_ << "Merge: "; + handle(m->outer()); + os_ << " and "; + handle(m->inner()); + os_ << " -> "; + handle(m->out()); + os_ << "\n"; } void IrPrinter::handle(const TransposeOp* top) { - indent(); - os_ << top->out() << " = transpose( " << top->in() << " )\n"; + indent() << top->out() << " = transpose( " << top->in() << " )\n"; } void IrPrinter::handle(const ShiftOp* sop) { - indent(); - os_ << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets() - << "}, padding = " << (sop->pad() ? "true" : "false") << " )\n"; + indent() << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets() + << "}, {" << sop->padWidth() << "} )\n"; } void IrPrinter::handle(const GatherOp* op) { - indent(); - os_ << op->out() << " = gather( " << op->in() << ", {"; + indent() << op->out() << " = gather( " << op->in() << ", {"; bool no_comma = true; for (const auto& s : op->windowShape()) { if (!no_comma) { @@ -392,36 +462,187 @@ void IrPrinter::handle(const GatherOp* op) { } void IrPrinter::handle(const ViewOp* top) { - indent(); - os_ << top->out() << " = view( " << top->in() << " )\n"; + indent() << top->out() << " = view( " << top->in() << " )\n"; } -void IrPrinter::handle(const Split* s) { - os_ << (s->innerSplit() ? "Split: " : "Outer split: "); - handle(s->in()); - os_ << " by factor " << s->factor() << " -> "; - handle(s->outer()); - os_ << ", "; - handle(s->inner()); - if (s->startOffset()) { - os_ << ", start offset: "; - handle(s->startOffset()); +void IrPrinter::handle(const kir::Predicate* node) { + switch (node->predicate_type()) { + case PredicateType::Inline: { + os_ << "Inline_Predicate"; + break; + } + case PredicateType::Manual: { + os_ << node->value(); + break; + } + case PredicateType::Misaligned: { + os_ << "Misaligned_Predicate"; + break; + } + case PredicateType::Padding: { + os_ << "Padding_Predicate"; + break; + } + case PredicateType::Shift: { + os_ << "Shift_Predicate"; + break; + } + case PredicateType::Unswitch: { + os_ << "Unswitch_Predicate"; + break; + } + case PredicateType::Vectorize: { + os_ << "Vectorize_Predicate"; + break; + } + default: + break; } - if (s->stopOffset()) { - os_ << ", stop offset: "; - handle(s->stopOffset()); +} + +void IrPrinter::handle(const kir::TensorIndex* ti) { + os_ << "T" << varName(ti); + switch (ti->view()->getMemoryType()) { + case MemoryType::Global: + os_ << "_g"; + break; + case MemoryType::Shared: + os_ << "_s"; + break; + case MemoryType::Local: + os_ << "_l"; + break; } + os_ << "["; + for (auto index : ti->indices()) { + print_inline(index); + if (index != ti->indices().back()) { + os_ << ", "; + } + } + os_ << "]"; + os_ << " view( T" << varName(ti->view()) << " )"; +} + +void IrPrinter::handle(const kir::Allocate* node) { + indent(); + handle(node->buffer()); + os_ << " = ALLOCATE(" + << "mem_type=" << node->memoryType() << ", " + << "size="; + print_inline(node->size()); + os_ << ", " + << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n"; + if (node->alias() != nullptr) { + indent() << kTab << ".alias="; + handle(node->alias()->buffer()); + os_ << "\n"; + } +} + +void IrPrinter::handle(const kir::Sync* node) { + indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync()) + << ")\n"; +} + +void IrPrinter::handle(const kir::ForLoop* node) { + indent() << "FOR "; + handle(node->index()); + os_ << " in "; + handle(node->iter_domain()); + os_ << ":\n"; + handleScope(node->body()); +} + +void IrPrinter::handle(const kir::IfThenElse* node) { + indent() << "IF "; + handle(node->predicate()); + os_ << ":\n"; + handleScope(node->thenBody()); + if (node->hasElse()) { + indent() << "ELSE:\n"; + handleScope(node->elseBody()); + } +} + +void IrPrinter::handle(const kir::GridBroadcast* node) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} + +void IrPrinter::handle(const kir::GridReduction* node) { + const auto* reduction_op = node->reduction_op(); + indent(); + handle(reduction_op->out()); + os_ << " = " + << "GRID_REDUCTION(op='" << reduction_op->getReductionOpType() << "'" + << ", in="; + handle(reduction_op->in()); + os_ << ", init="; + handle(reduction_op->init()); + os_ << ", pred="; + handle(reduction_op->predicate()); + os_ << ")\n"; + indent() << kTab << ".reduction_buffer="; + handle(node->reduction_buffer()->buffer()); + os_ << "\n"; + indent() << kTab << ".sync_buffer="; + handle(node->sync_buffer()->buffer()); + os_ << "\n"; + indent() << kTab << ".grid_pred="; + handle(node->predicate()); os_ << "\n"; } -void IrPrinter::handle(const Merge* m) { - os_ << "Merge: "; - handle(m->outer()); - os_ << " and "; - handle(m->inner()); - os_ << " -> "; - handle(m->out()); +void IrPrinter::handle(const kir::GridWelford* node) { + const auto* welford_op = node->welford_op(); + indent(); + handle(welford_op->outVar()); + os_ << ","; + handle(welford_op->outAvg()); + os_ << ","; + handle(welford_op->outN()); + os_ << " = " + << "GRID_WELFORD(" + << "inAvg="; + handle(welford_op->inAvg()); + if (!welford_op->inN()->isOneInt()) { + indent() << ", inVar="; + handle(welford_op->inVar()); + } + indent() << ", inN="; + handle(welford_op->inN()); + if (!welford_op->initN()->isZeroInt()) { + indent() << ", initVar="; + handle(welford_op->initVar()); + os_ << " initAvg="; + handle(welford_op->initAvg()); + os_ << " initN="; + handle(welford_op->initN()); + } + indent() << ", pred="; + handle(welford_op->predicate()); + os_ << ")\n"; + indent() << kTab << ".var_buffer="; + handle(node->var_buffer()->buffer()); + os_ << ".avg_buffer="; + handle(node->avg_buffer()->buffer()); + os_ << ".n_buffer="; + handle(node->N_buffer()->buffer()); + os_ << "\n"; + indent() << kTab << ".sync_buffer="; + handle(node->sync_buffer()->buffer()); os_ << "\n"; + indent() << kTab << ".grid_pred="; + handle(node->predicate()); + os_ << "\n"; +} + +void IrPrinter::handle(const kir::InitMagicZero* node) { + indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; +} + +void IrPrinter::handle(const kir::UpdateMagicZero* node) { + indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; } void IrTransformPrinter::handle(Fusion* f) { @@ -450,7 +671,7 @@ void IrTransformPrinter::printTransforms(TensorView* tv) { os() << ")\n"; for (auto exp : all_exp) { - os() << " "; + os() << " "; IrPrinter::handle(exp); } } diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index c080c3f8f99..f8c07886114 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -13,21 +13,30 @@ namespace jit { namespace fuser { namespace cuda { +class Fusion; +namespace kir { +class Kernel; +class Scope; +} // namespace kir + //! Define pretty printing functions for IR nodes //! //! This class is intended for debug printing, so it attempts //! to handle invalid states as well. //! class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { + static constexpr char const* kTab = " "; + public: explicit IrPrinter(std::ostream& os) : os_(os) {} // Indent the generated code - void indent() { + std::ostream& indent() { for (const auto i : c10::irange(indent_size_)) { (void)i; // Suppress unused variable warning os_ << " "; } + return os_; } void resetIndent() { @@ -38,6 +47,8 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { return print_inline_; } + using OptInConstDispatch::handle; + virtual void handle(Fusion* f); // handle calls some non const fusion ops, @@ -52,30 +63,50 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { handle(&f); } - void handle(const Statement* s) override; - void handle(const Val* v) override; - void handle(const Expr* e) override; - - void handle(const TensorDomain*) override; - void handle(const TensorView*) override; - void handle(const IterDomain*) override; - - void handle(const Bool*) override; - void handle(const Double*) override; - void handle(const Int*) override; - void handle(const NamedScalar*) override; - - void handle(const UnaryOp*) override; - void handle(const BinaryOp*) override; - void handle(const TernaryOp*) override; - void handle(const ReductionOp*) override; - void handle(const WelfordOp*) override; - void handle(const BroadcastOp*) override; - void handle(const TransposeOp*) override; - void handle(const ShiftOp*) override; - void handle(const GatherOp*) override; - void handle(const ViewOp*) override; - + virtual void handle(const kir::Kernel* kernel); + virtual void handle(kir::Kernel& kernel); + + void handleScope(const kir::Scope& scope); + + void handle(const Statement* s) final; + void handle(const Val* v) final; + void handle(const Expr* e) final; + + void handle(const IterDomain*) final; + void handle(const TensorDomain*) final; + void handle(const TensorView*) final; + + void handle(const Bool*) final; + void handle(const Double*) final; + void handle(const Int*) final; + void handle(const NamedScalar*) final; + + void handle(const UnaryOp*) final; + void handle(const BinaryOp*) final; + void handle(const TernaryOp*) final; + void handle(const ReductionOp*) final; + void handle(const WelfordOp*) final; + void handle(const BroadcastOp*) final; + void handle(const TransposeOp*) final; + void handle(const ShiftOp*) final; + void handle(const GatherOp*) final; + void handle(const ViewOp*) final; + + void handle(const kir::Predicate*) final; + void handle(const kir::TensorIndex*) final; + + void handle(const kir::GridBroadcast*) final; + void handle(const kir::GridReduction*) final; + void handle(const kir::GridWelford*) final; + void handle(const kir::ForLoop*) final; + void handle(const kir::IfThenElse*) final; + void handle(const kir::Allocate*) final; + void handle(const kir::Sync*) final; + void handle(const kir::InitMagicZero*) final; + void handle(const kir::UpdateMagicZero*) final; + + // IR math printer overrides these to prevent them from printing, keep + // override void handle(const Split*) override; void handle(const Merge*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 1465a88bef3..884b6a6e0ec 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -38,19 +40,19 @@ class ScalarCheck : OptInConstDispatch { } private: - void handle(const Bool* b) override { + void handle(const Bool* b) final { same_ = v1_->as()->sameAs(v2_->as()); } - void handle(const Double* d) override { + void handle(const Double* d) final { same_ = v1_->as()->sameAs(v2_->as()); } - void handle(const Int* i) override { + void handle(const Int* i) final { same_ = v1_->as()->sameAs(v2_->as()); } - void handle(const NamedScalar* ns) override { + void handle(const NamedScalar* ns) final { same_ = v1_->as()->sameAs(v2_->as()); } @@ -70,6 +72,16 @@ bool areEqualScalars(Val* v1, Val* v2) { return ScalarCheck::sameAs(v1, v2); } +Bool::Bool(IrBuilderPasskey passkey) + : Val(passkey, ValType::Scalar, DataType::Bool), + maybe_value_{c10::nullopt} {} + +Bool::Bool(IrBuilderPasskey passkey, bool value) + : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {} + +Bool::Bool(IrBuilderPasskey passkey, c10::optional value) + : Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {} + Bool::Bool(const Bool* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -87,6 +99,16 @@ bool Bool::sameAs(const Statement* other) const { return false; } +Double::Double(IrBuilderPasskey passkey) + : Val(passkey, ValType::Scalar, DataType::Double), + maybe_value_{c10::nullopt} {} + +Double::Double(IrBuilderPasskey passkey, ScalarType value) + : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {} + +Double::Double(IrBuilderPasskey passkey, c10::optional value) + : Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {} + Double::Double(const Double* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -103,6 +125,16 @@ bool Double::sameAs(const Statement* other) const { return false; } +Int::Int(IrBuilderPasskey passkey) + : Val(passkey, ValType::Scalar, DataType::Int), + maybe_value_{c10::nullopt} {} + +Int::Int(IrBuilderPasskey passkey, ScalarType value) + : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {} + +Int::Int(IrBuilderPasskey passkey, c10::optional value) + : Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {} + Int::Int(const Int* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} @@ -120,11 +152,13 @@ bool Int::sameAs(const Statement* other) const { return false; } -UnaryOp::UnaryOp(UnaryOpType type, Val* out, Val* in) - : Expr(ExprType::UnaryOp), unary_op_type_{type}, out_{out}, in_{in} { +UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in) + : Expr(passkey, ExprType::UnaryOp), + unary_op_type_{type}, + out_{out}, + in_{in} { addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) @@ -146,8 +180,13 @@ bool UnaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -BinaryOp::BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs) - : Expr(ExprType::BinaryOp), +BinaryOp::BinaryOp( + IrBuilderPasskey passkey, + BinaryOpType type, + Val* out, + Val* lhs, + Val* rhs) + : Expr(passkey, ExprType::BinaryOp), binary_op_type_{type}, out_{out}, lhs_{lhs}, @@ -155,7 +194,6 @@ BinaryOp::BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs) addOutput(out); addInput(lhs); addInput(rhs); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) @@ -178,8 +216,14 @@ bool BinaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -TernaryOp::TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3) - : Expr(ExprType::TernaryOp), +TernaryOp::TernaryOp( + IrBuilderPasskey passkey, + TernaryOpType type, + Val* out, + Val* in1, + Val* in2, + Val* in3) + : Expr(passkey, ExprType::TernaryOp), ternary_op_type_{type}, out_{out}, in1_{in1}, @@ -189,7 +233,6 @@ TernaryOp::TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3) addInput(in1); addInput(in2); addInput(in3); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) @@ -213,8 +256,12 @@ bool TernaryOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -BroadcastOp::BroadcastOp(Val* out, Val* in, std::vector is_broadcast_dims) - : Expr(ExprType::BroadcastOp), +BroadcastOp::BroadcastOp( + IrBuilderPasskey passkey, + Val* out, + Val* in, + std::vector is_broadcast_dims) + : Expr(passkey, ExprType::BroadcastOp), out_(out), in_(in), is_broadcast_dims_(std::move(is_broadcast_dims)) { @@ -226,12 +273,18 @@ BroadcastOp::BroadcastOp(Val* out, Val* in, std::vector is_broadcast_dims) auto in_type = in->getValType().value(); TORCH_INTERNAL_ASSERT( - out_type == ValType::TensorView && in_type == ValType::TensorView, + (out_type == ValType::TensorView && in_type == ValType::TensorView) || + (out_type == ValType::TensorIndex && in_type == ValType::TensorIndex), "Cannot braodcast a non-tensor object."); addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); + + if (!out->isA() || !in->isA()) { + return; + } + + passkey.ir_container_->registerExpr(exprPasskey(), this); // This is a generic check that root dims of a consumer and producer match. // Maybe we shouldn't relegate it to this constructor. @@ -294,37 +347,44 @@ bool BroadcastOp::sameAs(const Statement* other) const { } ReductionOp::ReductionOp( + IrBuilderPasskey passkey, BinaryOpType reduction_op_type, Val* init, Val* out, Val* in) - : Expr(ExprType::ReductionOp), + : Expr(passkey, ExprType::ReductionOp), reduction_op_type_(reduction_op_type), init_(init), out_(out), in_(in) { - TORCH_CHECK(out->getValType().value() == ValType::TensorView); + TORCH_CHECK( + out->getValType().value() == ValType::TensorView || + out->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - in->getValType() == ValType::TensorView && - out->getValType() == ValType::TensorView, + (in->getValType() == ValType::TensorView && + out->getValType() == ValType::TensorView) || + (in->getValType() == ValType::TensorIndex && + out->getValType() == ValType::TensorIndex), "Reduction operation was created that does not have tensor inputs and outputs."); - TORCH_INTERNAL_ASSERT( - TensorDomain::noReductions(in->as()->getMaybeRFactorDomain()) - .size() == out->as()->getRootDomain().size(), - "Reduction operation created with mismatched domains."); - + if (in->isA()) { + TORCH_INTERNAL_ASSERT( + TensorDomain::noReductions( + in->as()->getMaybeRFactorDomain()) + .size() == out->as()->getRootDomain().size(), + "Reduction operation created with mismatched domains."); + } TORCH_INTERNAL_ASSERT( init->isConstScalar(), "Tried to create a reduction operation whith an initial value that isn't a constant."); addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } WelfordOp::WelfordOp( + IrBuilderPasskey passkey, Val* out_avg, Val* out_var, Val* out_N, @@ -334,7 +394,7 @@ WelfordOp::WelfordOp( Val* in_avg, Val* in_var, Val* in_N) - : Expr(ExprType::WelfordOp), + : Expr(passkey, ExprType::WelfordOp), out_avg_(out_avg), out_var_(out_var), out_N_(out_N), @@ -345,9 +405,15 @@ WelfordOp::WelfordOp( in_var_(in_var), in_N_(in_N) { // Check output type - TORCH_INTERNAL_ASSERT(out_avg->getValType().value() == ValType::TensorView); - TORCH_INTERNAL_ASSERT(out_var->getValType().value() == ValType::TensorView); - TORCH_INTERNAL_ASSERT(out_N->getValType().value() == ValType::TensorView); + TORCH_INTERNAL_ASSERT( + out_avg->getValType().value() == ValType::TensorView || + out_avg->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT( + out_var->getValType().value() == ValType::TensorView || + out_var->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT( + out_N->getValType().value() == ValType::TensorView || + out_N->getValType().value() == ValType::TensorIndex); // check initial value TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar); @@ -356,22 +422,32 @@ WelfordOp::WelfordOp( // initial value with a count of 1 is un-common enough that I'll push // the responsibility of creating all-zero var tensors to the user TORCH_INTERNAL_ASSERT( - init_avg && init_avg->getValType().value() == ValType::TensorView); + init_avg && + (init_avg->getValType().value() == ValType::TensorView || + init_avg->getValType().value() == ValType::TensorIndex)); TORCH_INTERNAL_ASSERT( - init_var && init_var->getValType().value() == ValType::TensorView); + init_var && + (init_var->getValType().value() == ValType::TensorView || + init_var->getValType().value() == ValType::TensorIndex)); } TORCH_INTERNAL_ASSERT( - in_avg && in_avg->getValType().value() == ValType::TensorView); + in_avg && + (in_avg->getValType().value() == ValType::TensorView || + in_avg->getValType().value() == ValType::TensorIndex), + in_avg->getValType().value()); // check input TORCH_INTERNAL_ASSERT( in_N->getValType().value() == ValType::Scalar || - in_N->getValType().value() == ValType::TensorView); + in_N->getValType().value() == ValType::TensorView || + in_N->getValType().value() == ValType::TensorIndex); if (!in_N->isOneInt()) { // when input is only one value, only the value is required through avg // input the var part is implicitly 0 and codegen will handle that. TORCH_INTERNAL_ASSERT( - in_var && in_var->getValType().value() == ValType::TensorView); + in_var && + (in_var->getValType().value() == ValType::TensorView || + in_var->getValType().value() == ValType::TensorIndex)); } addOutput(out_avg); @@ -384,8 +460,6 @@ WelfordOp::WelfordOp( addInput(in_var); } addInput(in_N); - - name_ = FusionGuard::getCurFusion()->registerExpr(this); } WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) @@ -444,10 +518,11 @@ bool ReductionOp::sameAs(const Statement* other) const { } TransposeOp::TransposeOp( + IrBuilderPasskey passkey, TensorView* out, TensorView* in, std::vector new2old) - : Expr(ExprType::TransposeOp), + : Expr(passkey, ExprType::TransposeOp), out_(out), in_(in), new2old_(std::move(new2old)) { @@ -481,7 +556,6 @@ TransposeOp::TransposeOp( addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) @@ -490,12 +564,17 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), new2old_(src->new2old_) {} -ShiftOp::ShiftOp(Val* out, Val* in, std::vector offsets, bool pad) - : Expr(ExprType::ShiftOp), +ShiftOp::ShiftOp( + IrBuilderPasskey passkey, + Val* out, + Val* in, + std::vector offsets, + std::vector pad_width) + : Expr(passkey, ExprType::ShiftOp), out_(out), in_(in), offsets_(std::move(offsets)), - pad_(pad) { + pad_width_(std::move(pad_width)) { // clang-tidy complains about out_ that it may be null. TORCH_INTERNAL_ASSERT(out_ != nullptr); TORCH_INTERNAL_ASSERT(in_ != nullptr); @@ -514,9 +593,15 @@ ShiftOp::ShiftOp(Val* out, Val* in, std::vector offsets, bool pad) "Invalid offset vector: ", offsets_); + TORCH_INTERNAL_ASSERT( + pad_width_.size() == + TensorDomain::noReductions(in_->as()->getRootDomain()) + .size(), + "Invalid padding width vector: ", + pad_width_); + addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) @@ -524,7 +609,7 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)), offsets_(src->offsets_), - pad_(src->pad_) {} + pad_width_(src->pad_width_) {} bool ShiftOp::sameAs(const Statement* other) const { if (this == other) { @@ -541,11 +626,12 @@ bool ShiftOp::sameAs(const Statement* other) const { } GatherOp::GatherOp( + IrBuilderPasskey passkey, Val* out, Val* in, - std::vector window_shape, - std::vector> pad_width) - : Expr(ExprType::GatherOp), + std::vector window_shape, + std::vector> pad_width) + : Expr(passkey, ExprType::GatherOp), out_(out), in_(in), window_shape_(std::move(window_shape)), @@ -578,28 +664,14 @@ GatherOp::GatherOp( addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) { - std::transform( - src->window_shape_.begin(), - src->window_shape_.end(), - std::back_inserter(window_shape_), - [&ir_cloner](const auto& x) { return ir_cloner->clone(x); }); - for (const auto& pad : src->pad_width_) { - std::vector pad_clone; - std::transform( - pad.begin(), - pad.end(), - std::back_inserter(pad_clone), - [&ir_cloner](const auto& x) { return ir_cloner->clone(x); }); - pad_width_.push_back(pad_clone); - } -} + in_(ir_cloner->clone(src->in_)), + window_shape_(src->window_shape_), + pad_width_(src->pad_width_) {} bool GatherOp::sameAs(const Statement* other) const { if (this == other) { @@ -609,23 +681,10 @@ bool GatherOp::sameAs(const Statement* other) const { return false; } const auto other_op = other->as(); - if (windowShape().size() != other_op->windowShape().size()) { - return false; - } - for (const auto i : c10::irange(windowShape().size())) { - if (!windowShape()[i]->sameAs(other_op->windowShape()[i])) { - return false; - } - } - if (padWidth().size() != other_op->padWidth().size()) { + if (windowShape() != other_op->windowShape() || + padWidth() != other_op->padWidth()) { return false; } - for (const auto i : c10::irange(padWidth().size())) { - if (!padWidth()[i][0]->sameAs(other_op->padWidth()[i][0]) || - !padWidth()[i][1]->sameAs(other_op->padWidth()[i][1])) { - return false; - } - } return Expr::sameAs(other); } @@ -638,11 +697,10 @@ int GatherOp::gatherAxis(int axis) const { return int(windowShape().size()) + axis; } -ViewOp::ViewOp(TensorView* out, TensorView* in) - : Expr(ExprType::ViewOp), out_(out), in_(in) { +ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) + : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) { addOutput(out); addInput(in); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) @@ -651,12 +709,14 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)) {} IterDomain::IterDomain( + IrBuilderPasskey passkey, Val* start, Val* extent, ParallelType parallel_type, IterType iter_type, bool is_rfactor_domain) : IterDomain( + passkey, start, extent, nullptr, @@ -665,16 +725,19 @@ IterDomain::IterDomain( is_rfactor_domain) {} IterDomain::IterDomain( + IrBuilderPasskey passkey, Val* start, Val* extent, Val* stop_offset, ParallelType parallel_type, IterType iter_type, bool is_rfactor_domain) - : Val(ValType::IterDomain, DataType::Int, false), + : Val(passkey, ValType::IterDomain, DataType::Int), start_(start), extent_(extent), - stop_offset_(stop_offset == nullptr ? new Int(0) : stop_offset), + stop_offset_( + stop_offset == nullptr ? passkey.ir_container_->zeroVal() + : stop_offset), parallel_type_(parallel_type), iter_type_(iter_type), is_rfactor_domain_(is_rfactor_domain) { @@ -693,8 +756,6 @@ IterDomain::IterDomain( "Cannot create an iter domain with a start that is not an int but received ", start, " ."); - - name_ = fusion_->registerVal(this); } IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) @@ -729,6 +790,22 @@ bool IterDomain::sameAs(const Statement* other) const { return is_same; } +// Returns a new IterDomain matching properties of this +IterDomain* IterDomain::clone() const { + auto cloned = IrBuilder::create( + ir_container_, + start(), + extent(), + stopOffset(), + getParallelType(), + getIterType(), + isRFactorProduct()); + + cloned->is_padded_dimension_ = is_padded_dimension_; + cloned->padded_to_size_ = padded_to_size_; + return cloned; +} + std::vector IterDomain::clone( const std::vector& domains) { std::vector cloned_domains; @@ -781,14 +858,15 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { itype = IterType::Iteration; } - IterDomain* merged_id = new IterDomain( - new Int(0), + IterDomain* merged_id = IrBuilder::create( + outer->container(), + outer->container()->zeroVal(), merged_id_size->as(), outer->getParallelType(), itype, outer->isRFactorProduct() || inner->isRFactorProduct()); - new Merge(merged_id, outer, inner); + IrBuilder::create(outer->container(), merged_id, outer, inner); return merged_id; } @@ -811,7 +889,8 @@ std::pair IterDomain::split( if (factor->getValType() == ValType::Scalar) { TORCH_CHECK( factor->isConstScalar() || - FusionGuard::getCurFusion()->hasInput(factor), + (FusionGuard::getCurFusion() == factor->fusion() && + factor->isFusionInput()), factor, " is not a constant nor an input. It must be one or the other to be used in a split.", " If you want a symbolic split based on a thread dimension please use IterDomain::split(IterDomain*, ParallelType);"); @@ -832,24 +911,33 @@ std::pair IterDomain::split( in->definition() == nullptr, "Partial split is only allowed with root domains"); } - // outer loop IterDomain - IterDomain* ido = new IterDomain( - new Int(0), + IterDomain* ido = IrBuilder::create( + in->container(), + in->container()->zeroVal(), inner_split ? remainder->as() : factor, in->getParallelType(), in->getIterType(), in->isRFactorProduct()); // inner loop IterDomain - IterDomain* idi = new IterDomain( - new Int(0), + IterDomain* idi = IrBuilder::create( + in->container(), + in->container()->zeroVal(), inner_split ? factor : remainder->as(), in->getParallelType(), in->getIterType(), in->isRFactorProduct()); - new Split(ido, idi, in, factor, inner_split, start_offset, stop_offset); + IrBuilder::create( + in->container(), + ido, + idi, + in, + factor, + inner_split, + start_offset, + stop_offset); return {ido, idi}; } @@ -864,7 +952,9 @@ std::pair IterDomain::split( } std::pair IterDomain::stridedSplit(int factor) { - auto split_out = IterDomain::split(this, new Int(factor), true); + // Use partial split so that only valid values are retained + auto split_out = IterDomain::split( + this, IrBuilder::create(container(), factor), true, true); split_out.second->iter_type_ = IterType::Stride; split_out.first->is_rfactor_domain_ = true; @@ -907,9 +997,10 @@ Val* IterDomain::stop() const { } TensorDomain::TensorDomain( + IrBuilderPasskey passkey, std::vector root_domain, std::vector contiguity) - : Val(ValType::TensorDomain, DataType::Null, false), + : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), contiguity_( contiguity.empty() ? std::vector(root_domain_.size(), false) @@ -925,14 +1016,14 @@ TensorDomain::TensorDomain( has_nontrivial_reduction_ = false; domain_ = root_domain_; resetDomains(); - name_ = fusion_->registerVal(this); } TensorDomain::TensorDomain( + IrBuilderPasskey passkey, std::vector root_domain, std::vector domain, std::vector contiguity) - : Val(ValType::TensorDomain, DataType::Null, false), + : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), domain_(std::move(domain)), contiguity_( @@ -963,15 +1054,15 @@ TensorDomain::TensorDomain( // Just due to clang-tidy, correct value set in resetDomains has_nontrivial_reduction_ = false; resetDomains(); - name_ = fusion_->registerVal(this); } TensorDomain::TensorDomain( + IrBuilderPasskey passkey, std::vector root_domain, std::vector rfactor_domain, std::vector domain, std::vector contiguity) - : Val(ValType::TensorDomain, DataType::Null, false), + : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), domain_(std::move(domain)), rfactor_domain_(std::move(rfactor_domain)), @@ -1013,7 +1104,6 @@ TensorDomain::TensorDomain( // Just due to clang-tidy, correct value set in resetDomains has_nontrivial_reduction_ = false; resetDomains(); - name_ = fusion_->registerVal(this); } TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) @@ -1026,6 +1116,30 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) contiguity_(src->contiguity()), has_nontrivial_reduction_(src->has_nontrivial_reduction_) {} +namespace { +std::vector lowerIterDomains( + const std::vector& domains) { + std::vector lowered_domains; + lowered_domains.reserve(domains.size()); + for (const auto iter_domain : domains) { + lowered_domains.push_back(iter_domain); + } + return lowered_domains; +}; +} // namespace + +bool TensorDomain::hasBlockBroadcast() const { + return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { + return id->isBroadcast() && id->isThreadDim(); + }); +} + +bool TensorDomain::hasGridBroadcast() const { + return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { + return id->isBroadcast() && id->isBlockDim(); + }); +} + bool TensorDomain::operator==(const TensorDomain& other) const { // Checks equality of each class field. Should not be necessary to // check no_bcast_domain_ and no_reduction_domain_ as they are just @@ -1389,6 +1503,7 @@ std::pair TensorDomain::rFactor( } Split::Split( + IrBuilderPasskey passkey, IterDomain* outer, IterDomain* inner, IterDomain* in, @@ -1396,14 +1511,18 @@ Split::Split( bool inner_split, Val* start_offset, Val* stop_offset) - : Expr(ExprType::Split), + : Expr(passkey, ExprType::Split), outer_{outer}, inner_{inner}, in_{in}, factor_{factor}, inner_split_{inner_split}, - start_offset_{start_offset != nullptr ? start_offset : new Int(0)}, - stop_offset_{stop_offset != nullptr ? stop_offset : new Int(0)} { + start_offset_{ + start_offset != nullptr ? start_offset + : passkey.ir_container_->zeroVal()}, + stop_offset_{ + stop_offset != nullptr ? stop_offset + : passkey.ir_container_->zeroVal()} { TORCH_INTERNAL_ASSERT( factor_->isAnInt(), "Attempted to create a Split node with a non-integer factor."); @@ -1412,7 +1531,6 @@ Split::Split( addInput(in); // TODO add factor as an input, need to check Split::Split during validation // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } Split::Split(const Split* src, IrCloner* ir_cloner) @@ -1453,12 +1571,15 @@ bool Split::sameAs(const Statement* other) const { stopOffset()->sameAs(other->as()->stopOffset()); } -Merge::Merge(IterDomain* out, IterDomain* outer, IterDomain* inner) - : Expr(ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} { +Merge::Merge( + IrBuilderPasskey passkey, + IterDomain* out, + IterDomain* outer, + IterDomain* inner) + : Expr(passkey, ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} { addOutput(out); addInput(outer); addInput(inner); - name_ = FusionGuard::getCurFusion()->registerExpr(this); } Merge::Merge(const Merge* src, IrCloner* ir_cloner) @@ -1477,6 +1598,12 @@ bool Merge::sameAs(const Statement* other) const { return Expr::sameAs(other); } +NamedScalar::NamedScalar( + IrBuilderPasskey passkey, + std::string name, + DataType dtype) + : Val(passkey, ValType::NamedScalar, dtype), name_(std::move(name)) {} + NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner) : Val(src, ir_cloner), name_(src->name_) {} @@ -1495,13 +1622,15 @@ NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { isParallelTypeThread(p_type), "Cannot get parallel dim of non thread type, received: ", p_type); + TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); std::string parallel_dim = stringifyThreadSize(p_type); - return new NamedScalar(parallel_dim, DataType::Int); + return IrBuilder::create(parallel_dim, DataType::Int); } NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) { + TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); std::string parallel_ind = stringifyThread(p_type); - return new NamedScalar(parallel_ind, DataType::Int); + return IrBuilder::create(parallel_ind, DataType::Int); } c10::optional NamedScalar::getParallelDim() const { diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h index a2c14386147..91d07b76b80 100644 --- a/torch/csrc/jit/codegen/cuda/ir_printer.h +++ b/torch/csrc/jit/codegen/cuda/ir_printer.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 5bf05b0f516..004cfa23dff 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -140,7 +141,8 @@ struct SubstituteInExpr : public OptInDispatch { reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in(); auto out = reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out(); - expr_ = new UnaryOp(unary_expr->getUnaryOpType(), out, in); + expr_ = IrBuilder::create( + unary_expr->container(), unary_expr->getUnaryOpType(), out, in); } void handle(BinaryOp* binary_expr) final { @@ -151,7 +153,12 @@ struct SubstituteInExpr : public OptInDispatch { auto out = reference_->sameAs(binary_expr->out()) ? substitute_ : binary_expr->out(); - expr_ = new BinaryOp(binary_expr->getBinaryOpType(), out, lhs, rhs); + expr_ = IrBuilder::create( + binary_expr->container(), + binary_expr->getBinaryOpType(), + out, + lhs, + rhs); } void handle(TernaryOp* ternary_expr) final { @@ -163,7 +170,13 @@ struct SubstituteInExpr : public OptInDispatch { : ternary_expr->in3(); auto out = reference_->sameAs(ternary_expr->out()) ? substitute_ : ternary_expr->out(); - expr_ = new TernaryOp(ternary_expr->getTernaryOpType(), out, in1, in2, in3); + expr_ = IrBuilder::create( + ternary_expr->container(), + ternary_expr->getTernaryOpType(), + out, + in1, + in2, + in3); } void handle(ReductionOp* reduction_expr) final { @@ -176,8 +189,12 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(reduction_expr->in()) ? substitute_ : reduction_expr->in(); - expr_ = - new ReductionOp(reduction_expr->getReductionOpType(), init, out, in); + expr_ = IrBuilder::create( + reduction_expr->container(), + reduction_expr->getReductionOpType(), + init, + out, + in); } void handle(BroadcastOp* broadcast_expr) final { @@ -187,7 +204,11 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_ : broadcast_expr->in(); - expr_ = new BroadcastOp(out, in, broadcast_expr->getBroadcastDimFlags()); + expr_ = IrBuilder::create( + broadcast_expr->container(), + out, + in, + broadcast_expr->getBroadcastDimFlags()); } void handle(TransposeOp* transpose_expr) final { @@ -201,7 +222,8 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(transpose_expr->in()) ? substitute_->as() : transpose_expr->in(); - expr_ = new TransposeOp(out, in, transpose_expr->new2old()); + expr_ = IrBuilder::create( + transpose_expr->container(), out, in, transpose_expr->new2old()); } void handle(ShiftOp* shift_expr) final { @@ -210,7 +232,12 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in(); - expr_ = new ShiftOp(out, in, shift_expr->offsets(), shift_expr->pad()); + expr_ = IrBuilder::create( + shift_expr->container(), + out, + in, + shift_expr->offsets(), + shift_expr->padWidth()); } void handle(GatherOp* gather_expr) final { @@ -219,8 +246,12 @@ struct SubstituteInExpr : public OptInDispatch { auto in = reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in(); - expr_ = new GatherOp( - out, in, gather_expr->windowShape(), gather_expr->padWidth()); + expr_ = IrBuilder::create( + gather_expr->container(), + out, + in, + gather_expr->windowShape(), + gather_expr->padWidth()); } void handle(ViewOp* view_expr) final { @@ -234,7 +265,7 @@ struct SubstituteInExpr : public OptInDispatch { auto out = reference_->sameAs(view_expr->out()) ? substitute_->as() : view_expr->out(); - expr_ = new ViewOp(out, in); + expr_ = IrBuilder::create(view_expr->container(), out, in); } void handle(WelfordOp* welford_expr) final { @@ -268,7 +299,8 @@ struct SubstituteInExpr : public OptInDispatch { welford_expr->initN() && reference_->sameAs(welford_expr->initN()) ? substitute_ : welford_expr->initN(); - expr_ = new WelfordOp( + expr_ = IrBuilder::create( + welford_expr->container(), out_avg, out_var, out_N, @@ -402,13 +434,31 @@ std::vector allTvs(Fusion* fusion) { return uniqueEntries({used_tvs.begin(), used_tvs.end()}); } -std::vector historyOf(TensorDomain* td) { - return ExprSort::getExprs( - td->fusion(), {td->domain().begin(), td->domain().end()}); -} - -std::vector historyOf(TensorView* tv) { - return historyOf(tv->domain()); +std::vector getReductionOps(Fusion* fusion) { + std::vector red_ops; + for (auto expr : fusion->exprs()) { + const Val* out_val = nullptr; + if (expr->isA()) { + out_val = expr->as()->out(); + } else if (expr->isA()) { + out_val = expr->as()->outAvg(); + } else { + continue; + } + if (out_val == nullptr || !out_val->isA()) { + continue; + } + auto out_tv = out_val->as(); + if (std::any_of( + out_tv->getRootDomain().begin(), + out_tv->getRootDomain().end(), + [](IterDomain* id) { + return id->isReduction() && !id->isTrivialReduction(); + })) { + red_ops.push_back(expr); + } + } + return red_ops; } } // namespace ir_utils diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index c8dc2e6f679..1bf3f27ec0b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -110,6 +110,9 @@ auto filterByType(InputIt first, InputIt last) { return FilteredView(first, last); } +template +auto filterByType(const ContainerType&& inputs) = delete; + template auto filterByType(const ContainerType& inputs) { return filterByType(inputs.cbegin(), inputs.cend()); @@ -175,11 +178,7 @@ TORCH_CUDA_CU_API std::vector outputTvsOf( // returns all tensor views in fusion that are used between outputs and inputs. TORCH_CUDA_CU_API std::vector allTvs(Fusion* fusion); -// Returns the history of expressions applied to the domains of td -TORCH_CUDA_CU_API std::vector historyOf(TensorDomain* td); - -// Returns the history of expressions applied to the domains of tv -TORCH_CUDA_CU_API std::vector historyOf(TensorView* tv); +TORCH_CUDA_CU_API std::vector getReductionOps(Fusion* fusion); } // namespace ir_utils } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 344df98f5a7..894b40f79e3 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace torch { @@ -31,21 +32,94 @@ void remove_visited( } } +// Return all dependencies of a node including members of the node. +class RecursiveDependencies : public OptInDispatch { + public: + static std::vector next(Statement* stmt) { + RecursiveDependencies find_next(stmt); + return find_next.next_stmts_; + } + + private: + RecursiveDependencies() = default; + + RecursiveDependencies(Statement* stmt) { + handle(stmt); + } + + using OptInDispatch::handle; + + void handle(Expr* expr) final { + FusionGuard::getCurFusion()->assertInContainer( + expr, + "IterVisitor.cpp::RecursiveDependencies::handle(Expr*) Cannot traverse expr, "); + next_stmts_.insert( + next_stmts_.end(), expr->inputs().begin(), expr->inputs().end()); + } + + void handle(Val* val) final { + FusionGuard::getCurFusion()->assertInContainer( + val, + "IterVisitor.cpp::RecursiveDependencies::handle(Val*) Cannot traverse val, "); + OptInDispatch::handle(val); + } + + void simpleVal(Val* val) { + if (val->definition() == nullptr) { + return; + } + next_stmts_.push_back(val->definition()); + } + + void handle(Bool* stmt) final { + simpleVal(stmt); + } + + void handle(Double* stmt) final { + simpleVal(stmt); + } + + void handle(Int* stmt) final { + simpleVal(stmt); + } + + void handle(NamedScalar* stmt) final { + simpleVal(stmt); + } + + void handle(IterDomain* stmt) final { + next_stmts_.push_back(stmt->start()); + next_stmts_.push_back(stmt->extent()); + next_stmts_.push_back(stmt->stopOffset()); + simpleVal(stmt); + } + + void handle(TensorDomain* stmt) final { + next_stmts_.insert( + next_stmts_.end(), stmt->domain().begin(), stmt->domain().end()); + simpleVal(stmt); + } + + void handle(TensorView* tv) final { + next_stmts_.push_back(tv->domain()); + simpleVal(tv); + } + + std::vector next_stmts_; +}; + } // namespace std::vector IterVisitor::next(Statement* stmt) { if (stmt->isVal()) { return next(stmt->as()); - } else if (stmt->isExpr()) { - return next(stmt->as()); } else { - TORCH_INTERNAL_ASSERT( - false, "IterVisitor could not detect type in next_dispatch."); + return next(stmt->as()); } } std::vector IterVisitor::next(Val* v) { - FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, "); + FusionGuard::getCurFusion()->assertInContainer(v, "Cannot traverse val, "); if (v->definition() != nullptr) { return {v->definition()}; } @@ -53,7 +127,8 @@ std::vector IterVisitor::next(Val* v) { } std::vector IterVisitor::next(Expr* expr) { - FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, "); + FusionGuard::getCurFusion()->assertInContainer( + expr, "Cannot traverse expr, "); std::vector next_stmts{ expr->inputs().begin(), expr->inputs().end()}; return next_stmts; @@ -93,7 +168,8 @@ void IterVisitor::handle(Val* v) { void IterVisitor::traverseFrom( Fusion* fusion, const std::vector& from, - bool traverseAllPaths) { + bool traverseAllPaths, + bool traverseIntoMembers) { FusionGuard fg(fusion); std::unordered_set visited; @@ -137,7 +213,8 @@ void IterVisitor::traverseFrom( } else { // We're not ready to process this node, so add all its inputs to be // checked Visit input nodes. - auto next_stmts = next(stmt); + auto next_stmts = + traverseIntoMembers ? RecursiveDependencies::next(stmt) : next(stmt); // We may want to retraverse nodes, in that case revisit everything! if (!traverseAllPaths) { // If we don't want to retraverse, remove nodes we already visisted. @@ -308,7 +385,7 @@ void BackwardVisitor::traverseFrom( auto vals = AllVals::get(fusion, from); - auto exprs = ExprSort::getExprs(fusion, from); + auto exprs = StmtSort::getExprs(fusion, from); { size_t pos = 0; @@ -704,22 +781,41 @@ std::unordered_set DependencyCheck::getAllDependentVals( return DependentVals::getAllDependentVals(of); } -void ExprSort::handle(Expr* expr) { - exprs.push_back(expr); +void StmtSort::handle(Statement* stmt) { + stmts.push_back(stmt); } -std::vector ExprSort::getExprs(Fusion* fusion) { - ExprSort es; - es.traverse(fusion); - return es.exprs; +std::vector StmtSort::getExprs(Fusion* fusion, bool traverse_members) { + auto terminating_outputs = fusion->getTerminatingOutputs(); + return StmtSort::getExprs(fusion, terminating_outputs, traverse_members); } -std::vector ExprSort::getExprs( +std::vector StmtSort::getExprs( Fusion* fusion, - const std::vector& from) { - ExprSort es; - es.traverseFrom(fusion, from, false); - return es.exprs; + const std::vector& from, + bool traverse_members) { + StmtSort es; + es.traverseFrom(fusion, from, false, traverse_members); + auto stmts = StmtSort::getStmts(fusion, from, traverse_members); + auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); + std::vector exprs(filter.begin(), filter.end()); + return exprs; +} + +std::vector StmtSort::getStmts( + Fusion* fusion, + bool traverse_members) { + auto terminating_outputs = fusion->getTerminatingOutputs(); + return StmtSort::getStmts(fusion, terminating_outputs, traverse_members); +} + +std::vector StmtSort::getStmts( + Fusion* fusion, + const std::vector& from, + bool traverse_members) { + StmtSort es; + es.traverseFrom(fusion, from, false, traverse_members); + return es.stmts; } void InputsOf::handle(Val* v) { diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index d4aa56ea2fe..2447933d737 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -83,18 +83,21 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { void traverseHelper(Fusion* fusion, bool traverse_all_paths = false); public: - // Starts at nodes provided in from, traverses from these nodes to inputs. - // Calls handle on all Statement*s in topological sorted order. - // traverseAllPaths = false only call handle on each Statement* once - // traverseAllPaths = true traverses all paths from nodes in from to inputs. - // Handle on a Statement* for every path from "from" nodes, to inputs. - // to argument allows specification of nodes to stop at if we want to stop - // beffore we hit all leaf nodes. This can be helpful when we want to traverse - // from TensorView::domain(), to the rfactor domain, instead of root domain. + //! Starts at nodes provided in from, traverses from these nodes to inputs. + //! Calls handle on all Statement*s in topological sorted order. + //! \param traverseAllPaths = false only call handle on each Statement* once + //! traverseAllPaths = true traverses all paths from nodes in from to + //! inputs. Calls handle on a Statement* for every path from "from" nodes, + //! to inputs. + //! \param traverseIntoMembers = When hitting nodes like TensorView, + //! TensorDomain, or IterDomain where there are members of the nodes that are + //! Val's a value of "true" will also traverse into those member Val's, a + //! value of "false" will not traverse into the members. void traverseFrom( Fusion* fusion, const std::vector& from, - bool traverseAllPaths = false); + bool traverseAllPaths = false, + bool traverseIntoMembers = false); // Iterates from terminating outputs registered with the fusion. Terminating // means value is not used to generate any other value used in producing @@ -246,18 +249,40 @@ class TORCH_CUDA_CU_API DependencyCheck { // Expr sort will take a fusion and return a topologically sorted list of // expressions. -class ExprSort : public IterVisitor { +class StmtSort : public IterVisitor { protected: - std::vector exprs; + std::vector stmts; - void handle(Expr* expr) override; + void handle(Statement* stmt) override; public: - static std::vector getExprs(Fusion* fusion); + // If traverse_members it will also extract all member nodes in the sorted + // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc + static std::vector getExprs( + Fusion* fusion, + bool traverse_members = false); + // If traverse_members it will also extract all member nodes in the sorted + // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc static std::vector getExprs( Fusion* fusion, - const std::vector& from); + const std::vector& from, + bool traverse_members = false); + + // If traverse_members it will also extract all member nodes in the sorted + // statement list in the fusion. i.e. all IterDomains, extents, and associated + // expressions of them + static std::vector getStmts( + Fusion* fusion, + bool traverse_members = false); + + // If traverse_members it will also extract all member nodes in the sorted + // expr list in the fusion. i.e. all IterDomains, extents, and associated + // expressions of them + static std::vector getStmts( + Fusion* fusion, + const std::vector& from, + bool traverse_members = false); }; class InputsOf : public IterVisitor { diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index d3ef9eeb95d..b9062f5bc45 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -1,7 +1,8 @@ #include +#include #include #include -#include +#include #include #include @@ -11,22 +12,24 @@ namespace torch { namespace jit { namespace fuser { namespace cuda { + +IrBuilderPasskey::IrBuilderPasskey(IrContainer* ir_container) + : ir_container_(ir_container) {} + namespace kir { namespace { //! Scan all primary expressions in the Kernel IR and build //! lists of specialized nodes and other interesting information -class KernelIrScanner : private kir::IrVisitor { +class KernelIrScanner : private IrVisitor { public: explicit KernelIrScanner(const Kernel* kernel) { - for (const auto& ir_node : kernel->irNodes()) { - ir_node->accept(this); - } + IrVisitor::handle(kernel->topLevelExprs()); const auto gpu_lower = GpuLower::current(); for (auto split : gpu_lower->nonDivisibleSplitInfo().splitsToValidate()) { - auto extent = gpu_lower->lowerValue(split->in()->extent()); - auto factor = gpu_lower->lowerValue(split->factor()); + auto extent = split->in()->extent(); + auto factor = split->factor(); summary_.splits_to_validate.emplace_back(extent, factor); } } @@ -36,7 +39,17 @@ class KernelIrScanner : private kir::IrVisitor { } private: - void visit(const kir::Sync* sync) final { + using IrVisitor::handle; + void handle(Expr* expr) final { + IrVisitor::handle(expr); + for (auto inp : expr->inputs()) { + handle(inp); + } + for (auto out : expr->outputs()) { + handle(out); + } + } + void handle(Sync* sync) final { // TODO: Move to a dedicated validation pass // which is not on the common execution/compilation path if (sync->isWarHazardSync()) { @@ -44,7 +57,7 @@ class KernelIrScanner : private kir::IrVisitor { } } - void visit(const kir::Allocate* allocate) final { + void handle(Allocate* allocate) final { switch (allocate->memoryType()) { case MemoryType::Global: summary_.global_allocations.push_back(allocate); @@ -65,28 +78,23 @@ class KernelIrScanner : private kir::IrVisitor { } } - void visit(const kir::UnaryOp* unary_op) final { - if (unary_op->operation() == UnaryOpType::RandLike) { + void handle(UnaryOp* unary_op) final { + if (unary_op->getUnaryOpType() == UnaryOpType::RandLike) { // This kernel is using random numbers summary_.is_stochastic = true; } } - void visit(const kir::TensorIndex* tensor_index) final { + void handle(TensorIndex* tensor_index) final { const auto tv = tensor_index->view(); const auto domain = tv->domain(); - // Do we have any reductions? summary_.has_block_reductions = summary_.has_block_reductions || domain->hasBlockReduction(); - // Do we have block broadcasts? - summary_.has_block_broadcasts = - summary_.has_block_broadcasts || domain->hasBlockBroadcast(); - // Update the largest smem data type if (domain->hasBlockReduction() || domain->hasGridReduction() || - tv->memoryType() == MemoryType::Shared) { + tv->getMemoryType() == MemoryType::Shared) { const auto data_type = tv->dtype(); const size_t type_size = dataTypeSize(data_type); if (type_size > max_smem_type_size_) { @@ -94,38 +102,50 @@ class KernelIrScanner : private kir::IrVisitor { summary_.largest_smem_data_type = data_type; } } + } - // Update Welford - if (tensor_index->definition() != nullptr && - tensor_index->definition()->isA()) { - summary_.has_welford = true; - summary_.has_block_welford = - summary_.has_block_welford || domain->hasBlockReduction(); - summary_.has_grid_welford = - summary_.has_grid_welford || domain->hasGridReduction(); - } + void handle(WelfordOp* welford_op) final { + summary_.has_welford = true; + TORCH_INTERNAL_ASSERT(welford_op->outAvg()->isA()); + auto out_dom = welford_op->outAvg()->as()->view()->domain(); + summary_.has_block_welford = + summary_.has_block_welford || out_dom->hasBlockReduction(); } - void visit(const kir::GridWelford* grid_welford) final { - const auto dom = grid_welford->welford_op() - ->out() - ->as() - ->view() - ->domain(); + void handle(GridWelford* grid_welford) final { + summary_.has_welford = true; + summary_.has_grid_welford = true; + const auto dom = + grid_welford->welford_op()->out()->as()->view()->domain(); updateGridReductionInLoop(dom); } - void visit(const kir::GridReduction* grid_reduction) final { + void handle(GridReduction* grid_reduction) final { + summary_.has_grid_reductions = true; const auto dom = grid_reduction->reduction_op() ->out() - ->as() + ->as() ->view() ->domain(); updateGridReductionInLoop(dom); } - void visit(const kir::GridBroadcast*) final { + void handle(GridBroadcast* grid_broadcast) final { summary_.has_cooperative_grid_reduction = true; + handle(grid_broadcast->broadcast_op()); + } + + void handle(BroadcastOp* bop) final { + const ParallelTypeBitmap parallel_types = + GpuLower::current()->threadPredMap().getParallelBroadcastDomains( + bop->out()->as()->view()); + summary_.broadcast_parallel_types.emplace(bop, parallel_types); + // Do we have block broadcasts? + summary_.has_block_broadcasts = + summary_.has_block_broadcasts || parallel_types.hasTID(); + // Do we have grid broadcasts? + summary_.has_grid_broadcasts = + summary_.has_grid_broadcasts || parallel_types.hasBID(); } private: @@ -136,10 +156,9 @@ class KernelIrScanner : private kir::IrVisitor { void updateGridReductionInLoop(TensorDomain* dom) { summary_.has_grid_reductions = true; - const auto gpu_lower = GpuLower::current(); for (const auto i : c10::irange(dom->nDims())) { - const auto id = - gpu_lower->caParallelMap().getConcreteMappedID(dom->domain()[i]); + const auto id = GpuLower::current()->caParallelMap().getConcreteMappedID( + dom->domain()[i]); summary_.has_cooperative_grid_reduction = summary_.has_cooperative_grid_reduction || @@ -169,7 +188,7 @@ class KernelIrScanner : private kir::IrVisitor { //! MemoryType::Global for tensors parallelized with blockIdx), it is //! assumed that allocation is properly extended for the iteration //! count. -class ValidateAllocation : private kir::IrVisitor { +class ValidateAllocation : private OptOutConstDispatch { public: static void validate(const Kernel* kernel) { ValidateAllocation validate_allocation(kernel); @@ -178,14 +197,14 @@ class ValidateAllocation : private kir::IrVisitor { private: explicit ValidateAllocation(const Kernel* kernel) { live_allocations_.emplace_back(std::vector()); - for (const auto& ir_node : kernel->topLevelExprs()) { - ir_node->accept(this); + for (const auto& expr : kernel->topLevelExprs()) { + OptOutConstDispatch::handle(expr); } live_allocations_.pop_back(); TORCH_INTERNAL_ASSERT(live_allocations_.empty()); } - void visit(const kir::Allocate* allocate) final { + void handle(const Allocate* allocate) final { TORCH_INTERNAL_ASSERT(!live_allocations_.empty()); live_allocations_.back().push_back(allocate); } @@ -195,53 +214,52 @@ class ValidateAllocation : private kir::IrVisitor { // during in the allocation lowering if it's thread-parallel and not // allocated on shared or global memories, or if it's block-parallel // ando not allocated on global memory. - void validate(const kir::ForLoop* for_loop) { + void validate(const ForLoop* for_loop) { const auto loop_id = for_loop->iter_domain(); - const auto gpu_lower = GpuLower::current(); for (const auto& allocations : live_allocations_) { for (const auto& allocate : allocations) { - const auto tv = dynamic_cast(allocate->buffer()); + const auto tv = dynamic_cast(allocate->buffer()); if (tv == nullptr) { continue; } for (const auto& axis : tv->domain()->domain()) { - if (!gpu_lower->caParallelMap().areMapped(loop_id, axis)) { + if (!GpuLower::current()->caParallelMap().areMapped(loop_id, axis)) { continue; } - if (isParallelTypeThreadDim(loop_id->parallelType())) { + if (isParallelTypeThreadDim(loop_id->getParallelType())) { TORCH_INTERNAL_ASSERT( - tv->memoryType() == MemoryType::Shared || - tv->memoryType() == MemoryType::Global, + tv->getMemoryType() == MemoryType::Shared || + tv->getMemoryType() == MemoryType::Global, "Tensor t", tv->name(), " must be allocated on SMEM or GMEM."); - } else if (isParallelTypeBlockDim(loop_id->parallelType())) { - TORCH_INTERNAL_ASSERT(tv->memoryType() == MemoryType::Global); + } else if (isParallelTypeBlockDim(loop_id->getParallelType())) { + TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global); } } } } } - void visit(const kir::ForLoop* for_loop) final { + void handle(const ForLoop* for_loop) final { if (for_loop->stop() != for_loop->iter_domain()->extent() && - isParallelTypeThread(for_loop->iter_domain()->parallelType())) { + isParallelTypeThread(for_loop->iter_domain()->getParallelType())) { validate(for_loop); } live_allocations_.emplace_back(std::vector()); for (const auto& expr : for_loop->body().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } live_allocations_.pop_back(); } - void visit(const kir::IfThenElse* ite) final { + void handle(const IfThenElse* ite) final { for (const auto& expr : ite->thenBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } for (const auto& expr : ite->elseBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } @@ -252,11 +270,9 @@ class ValidateAllocation : private kir::IrVisitor { } // namespace // TODO(kir): Kernel IR validation -void Kernel::finalize(std::vector top_level_exprs) { - TORCH_CHECK(top_level_exprs_.empty()); +void Kernel::finalize(std::vector top_level_exprs) { + TORCH_INTERNAL_ASSERT(top_level_exprs_.empty()); top_level_exprs_ = std::move(top_level_exprs); - predicate_map_ = std::make_unique( - GpuLower::current()->threadPredMap()); warp_padded_parallel_info_ = GpuLower::current()->getWarpPaddedParallelInfo(); ValidateAllocation::validate(this); analyze(); @@ -270,8 +286,63 @@ void Kernel::analyze() { } void Kernel::print() const { - kir::IrPrinter ir_printer(std::cout); - ir_printer.printKernel(this); + IrPrinter ir_printer(std::cout); + ir_printer.handle(this); +} + +//! Register the Val with this fusion +void Kernel::registerVal(Val* val) { + if (inContainer(val)) { + return; + } + if (val->kernel()) { + TORCH_CHECK( + val->kernel() == this, + val->toString(), + " was not found in the active kernel."); + } + + Fusion::registerVal(val); +} + +//! Register expr with this fusion. +//! When we register an expression, we want to update the dependency tracking +//! of Vals. We add expr to our general expr_set_, +void Kernel::registerExpr(Expr* expr) { + if (inContainer(expr)) { + return; + } + + if (expr->kernel()) { + TORCH_CHECK( + expr->kernel() == this, + expr->toString(), + " was not found in the active kernel."); + } + + for (Val* input : expr->inputs()) { + TORCH_INTERNAL_ASSERT( + inContainer(input), + "Input\n", + input->toString(), + " to expr,\n", + expr->toString(), + ",\n is invalid because it is not in the same kernel."); + } + + for (Val* output : expr->outputs()) { + TORCH_INTERNAL_ASSERT( + inContainer(output), + "Output\n", + output->toString(), + " to expr,\n", + expr->toString(), + ",\n is invalid because it is not in the same kernel."); + } + + // Register expr is explicitly non-SSA when coming from a kernel. This is + // detected inside Fusion::registerExpr + Fusion::registerExpr(expr); } } // namespace kir diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h index b273324e1e2..0c8bbdef9df 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.h +++ b/torch/csrc/jit/codegen/cuda/kernel.h @@ -1,12 +1,15 @@ #pragma once -#include -#include -#include +#include + +#include +#include +#include #include #include #include +#include #include #include @@ -47,6 +50,9 @@ struct KernelSummary { //! Do we have any block broadcasts? bool has_block_broadcasts = false; + //! Do we have any grid broadcasts? + bool has_grid_broadcasts = false; + //! Do we have any welford op? bool has_welford = false; @@ -67,87 +73,47 @@ struct KernelSummary { std::vector dynamic_lmem_allocations; //! ceilDiv extents that must be divisible - std::vector> splits_to_validate; + std::vector> splits_to_validate; + + //! Effective ParallelTypes of broadcast ops + std::unordered_map + broadcast_parallel_types; }; //! Container for a lowered Kernel IR //! -//! TODO(kir): currently, it is just pointing to nodes owned -//! by a Fusion object. The goal is to have the Kernel object -//! own the Kernel IR nodes -//! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API Kernel final : public NonCopyable { +class TORCH_CUDA_CU_API Kernel final : public Fusion { public: - Kernel() = default; + // Kernel starts by grabbing all the nodes from the provided fusion. + // Kernel is not SSA, if a definition is not set, we should update it, but + // not remove previous definition if it is set. This is primarily because when + // we do something like generate an initialization statement for a reduction + // TV, we may want to continue to do fusion like analysis on the original + // expression. + Kernel(Fusion* fusion) : Fusion(*fusion) {} + + Kernel() = delete; + + // No move or copy semantics + Kernel(const Kernel&) = delete; + Kernel& operator=(const Kernel&) = delete; //! Finalize a kernel definition //! //! At this point we have a complete kernel definition and we can //! run analysis passes to build a KernelSummary //! - void finalize(std::vector top_level_exprs); - - //! Register input as an input of the kernel - void addInput(Val* input) { - inputs_.push_back(input); - input_set_.insert(input); - } + void finalize(std::vector top_level_exprs); - //! Register output as an output of the kernel - void addOutput(Val* output) { - outputs_.push_back(output); - output_set_.insert(output); - } - - const auto& inputs() const { - return inputs_; - } - - const auto& outputs() const { - return outputs_; - } - - bool isInput(Val* val) const { - return input_set_.find(val) != input_set_.end(); - } - - bool isOutput(Val* val) const { - return output_set_.find(val) != output_set_.end(); - } - - const auto& topLevelExprs() const { + const std::vector& topLevelExprs() const { return top_level_exprs_; } - const auto& irNodes() const { - return ir_nodes_; - } - const KernelSummary& summary() const { return summary_; } - const ThreadPredicateMap& predicateMap() const { - return *predicate_map_; - } - - //! Register a new Kernel IR node - //! - //! \note This is a specialized helper for kir::IrBuilder, not - //! intendted for general use - //! - void registerIrNode(kir::Passkey passkey, std::unique_ptr node) { - TORCH_CHECK(passkey.kernel == this); - ir_nodes_.push_back(std::move(node)); - } - - //! Allocates a new value identifier - kir::ValueId newValueId(kir::Passkey passkey) { - TORCH_CHECK(passkey.kernel == this); - return next_value_id_++; - } - //! Checks if parallel type is padded bool isParallelTypePadded(ParallelType ptype) const { return ptype == ParallelType::TIDx && @@ -161,32 +127,26 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable { //! Debug dump of the Kernel IR void print() const; + protected: + //! Register the Val with this fusion + void registerVal(Val* val) override; + + //! Register expr with this fusion. + //! When we register an expression, we want to update the dependency tracking + //! of Vals. We add expr to our general expr_set_, + void registerExpr(Expr* expr) override; + private: // Analyze the kernel IR and caches the summary of interesting data void analyze(); private: - // Kernel IR nodes - std::vector> ir_nodes_; - // Top level statements - std::vector top_level_exprs_; - - // Kernel inputs and outputs - std::vector inputs_; - std::vector outputs_; - std::unordered_set input_set_; - std::unordered_set output_set_; - - // Used to allocate unique value IDs - kir::ValueId next_value_id_ = 1; + std::vector top_level_exprs_; // Summary of interesting kernel data KernelSummary summary_; - // Predicate map - // TODO(kir): consider a simpler, kernel IR based version - std::unique_ptr predicate_map_; WarpPaddedParallelInfo warp_padded_parallel_info_; }; diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index 39350876bd2..c1c113dbbc4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace torch { namespace jit { @@ -25,6 +26,10 @@ int getCommonDeviceCUDA(const at::ArrayRef& inputs) { continue; } const auto& device = input.toTensor().device(); + // skip cpu scalar tensor as they'll be promoted to scalar later + if (device.is_cpu() && is_cpu_scalar(input.toTensor())) { + continue; + } TORCH_CHECK(device.is_cuda(), "nvfuser only supports cuda device"); auto cur_index = device.index(); if (index != -1 && index != cur_index) { @@ -202,9 +207,9 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } // Access kernels associated with the common device id - auto dev_id = getCommonDeviceCUDA(inputs); - TORCH_INTERNAL_ASSERT(dev_id >= 0); - auto& kernel_runtimes = kernel_runtimes_[dev_id]; + auto device_index = getCommonDeviceCUDA(inputs); + TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs"); + auto& kernel_runtimes = kernel_runtimes_[device_index]; // Check for re-use hit case // a kernel runtime is re-usable if all the compiled @@ -277,14 +282,6 @@ FusionKernelRuntime::FusionKernelRuntime( } else { auto complete_fusion_heuristic = maybe_complete_fusion_heuristic.value(); - // Translate welfords if apply - if (fusion_copy->hasWelford()) { - bool translated = SegmentCandidateFinder::TranslateWelfordInFusion( - fusion_copy.get(), inputs); - if (translated) { - complete_fusion_heuristic = ScheduleHeuristic::Persistent; - } - } // Take ownership of the transformed fusion single_kernel_fusion_ = std::move(fusion_copy); @@ -358,7 +355,7 @@ std::vector FusionKernelRuntime::runKernelWithInput( launch_params = scheduler_entry->pointwiseParams().lparams; } executors_[group_id].compileFusion( - fusion_to_run.get(), options, inputs, launch_params); + fusion_to_run.get(), inputs, launch_params, options); } else { // Load launch params for reduction and normalization kernels if (scheduler_entry->hasReductionParam()) { @@ -453,6 +450,7 @@ std::vector FusionKernelRuntime::runWithInput( " inputs but expecting ", segmented_fusion_->inputs().size()); + c10::Device device(c10::DeviceType::CUDA, 0); int extent_index_ = 0; // Bind input in the tensor_map for (const auto i : c10::irange(inputs.size())) { @@ -466,6 +464,7 @@ std::vector FusionKernelRuntime::runWithInput( // more convenient and safer than replication if (inputs[i].isTensor()) { auto aten_tensor = inputs[i].toTensor(); + device = aten_tensor.device(); for (auto dim_size : aten_tensor.sizes()) { runtime_workspace_.tensor_map.emplace( runtime_workspace_.group_extent_binding_order[extent_index_++], @@ -504,14 +503,30 @@ std::vector FusionKernelRuntime::runWithInput( if (iter != runtime_workspace_.tensor_map.end()) { fusion_outputs.push_back(iter->second); } else { + bool empty_type_check = output->getDataType().has_value() && + output->getDataType().value() == DataType::Float; + + // Only support two cases of empty tensor here, since + // this is hot path. + auto out_tv = output->as(); + + // TODO: should be only one of the two once the "empty" + // definition has been unified throughout the ops. + bool empty_tensor_check = + out_tv->isZeroDim() || out_tv->isEmptyTensor(); + // This is the check for an empty tensor; TORCH_INTERNAL_ASSERT( - output->as()->nDims() == 0 && - output->getDataType().has_value() && - output->getDataType().value() == DataType::Float, + empty_tensor_check && empty_type_check, "Non empty tensor cannot be found at tensor_map in ", __FUNCTION__); - fusion_outputs.emplace_back(at::Tensor()); + + // TODO: would need to clean up this part when + // we have a unified and consistent way to generate + // size-0 tensors. + const auto tensor_options = + at::TensorOptions().dtype(at::kFloat).device(device); + fusion_outputs.emplace_back(at::empty({0}, tensor_options)); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index ae84c25e4f2..cba42f99dc4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -7,8 +7,8 @@ #include #include +#include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp index 7421d2e235a..3605f7a4155 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp @@ -1,7 +1,6 @@ #include #include -#include #include @@ -16,11 +15,11 @@ void ExpressionEvaluator::bind( Int::ScalarType concrete_value) { TORCH_CHECK(value->isScalar()); TORCH_CHECK(value->dtype() == DataType::Int); - TORCH_CHECK(!value->isConst(), "Tried to bind to a constant value"); + TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value"); TORCH_CHECK( value->definition() == nullptr, "Tried to bind to a value that is computed in the kernel IR: ", - toString(value), + value->toString(), " with ", concrete_value); known_values_[value] = concrete_value; @@ -41,14 +40,18 @@ void ExpressionEvaluator::bind( c10::optional ExpressionEvaluator::evaluate(const Val* value) { if (precomputed_integers_ && precomputed_integers_->ready()) { - return precomputed_integers_->getMaybeValueFor(value); - } else if (value->isScalar() && value->isConst()) { + if (precomputed_integers_->getMaybeValueFor(value).has_value()) { + return precomputed_integers_->getMaybeValueFor(value); + } + } + + if (value->isScalar() && value->isConst()) { return value->as()->value(); } else { FUSER_PERF_SCOPE("kir::ExpressionEvaluator::evaluate"); - TORCH_CHECK(value->isScalar()); - TORCH_CHECK(value->dtype() == DataType::Int); + TORCH_CHECK(value->isScalar(), value->toString()); + TORCH_CHECK(value->dtype() == DataType::Int, value->toString()); // Is the value known (either explicit binding or memoized)? const auto pre_eval_it = known_values_.find(value); @@ -56,7 +59,7 @@ c10::optional ExpressionEvaluator::evaluate(const Val* value) { return pre_eval_it->second; } - value->accept(this); + OptOutConstDispatch::handle(value); const auto post_eval_it = known_values_.find(value); return post_eval_it != known_values_.end() @@ -74,24 +77,23 @@ void ExpressionEvaluator::print() const { std::cout << "\nEvaluation context\n"; std::cout << "--------------------\n"; for (const auto& kv : known_values_) { - std::cout << toString(kv.first) << " = " << kv.second << "\n"; + std::cout << kv.first->toString() << " = " << kv.second << "\n"; + } + std::cout << "\nPre-computed Values\n"; + if (precomputed_integers_ != nullptr) { + precomputed_integers_->print(); } std::cout << "--------------------\n\n"; } -void ExpressionEvaluator::unhandled(const void*) { - TORCH_INTERNAL_ASSERT( - false, "Kernel IR expression evaluation reached an unsupported node"); -} - -void ExpressionEvaluator::visit(const Int* value) { +void ExpressionEvaluator::handle(const Int* value) { TORCH_INTERNAL_ASSERT(!value->isConst()); if (auto def = value->definition()) { - def->accept(this); + OptOutConstDispatch::handle(def); } } -void ExpressionEvaluator::visit(const NamedScalar* named_scalar) { +void ExpressionEvaluator::handle(const NamedScalar* named_scalar) { const auto& name = named_scalar->name(); for (auto pt : kParallelTypeThreads) { auto pt_val_it = known_parallel_dimensions_.find(pt); @@ -105,10 +107,10 @@ void ExpressionEvaluator::visit(const NamedScalar* named_scalar) { } } -void ExpressionEvaluator::visit(const UnaryOp* unary_op) { +void ExpressionEvaluator::handle(const UnaryOp* unary_op) { const auto in = evaluate(unary_op->in()); if (in.has_value()) { - switch (unary_op->operation()) { + switch (unary_op->getUnaryOpType()) { case UnaryOpType::Neg: known_values_[unary_op->out()] = -*in; break; @@ -121,11 +123,11 @@ void ExpressionEvaluator::visit(const UnaryOp* unary_op) { } } -void ExpressionEvaluator::visit(const BinaryOp* binary_op) { +void ExpressionEvaluator::handle(const BinaryOp* binary_op) { const auto lhs = evaluate(binary_op->lhs()); const auto rhs = evaluate(binary_op->rhs()); if (lhs.has_value() && rhs.has_value()) { - switch (binary_op->operation()) { + switch (binary_op->getBinaryOpType()) { case BinaryOpType::Add: known_values_[binary_op->out()] = *lhs + *rhs; break; diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h index 64791387543..63586857ad8 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h @@ -1,7 +1,9 @@ #pragma once -#include +#include + +#include #include #include @@ -34,7 +36,7 @@ namespace kir { //! } //! ``` //! -class TORCH_CUDA_CU_API ExpressionEvaluator : private IrVisitor { +class TORCH_CUDA_CU_API ExpressionEvaluator : private OptInConstDispatch { public: //! Set a concrete value for a symbolic value void bind(const Val* value, Int::ScalarType concrete_value); @@ -56,11 +58,10 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private IrVisitor { } private: - void unhandled(const void*) final; - void visit(const Int* value) final; - void visit(const NamedScalar* named_scalar) final; - void visit(const UnaryOp* unary_op) final; - void visit(const BinaryOp* binary_op) final; + void handle(const Int* value) final; + void handle(const NamedScalar* named_scalar) final; + void handle(const UnaryOp* unary_op) final; + void handle(const BinaryOp* binary_op) final; private: std::unordered_map known_values_; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index eebfd41729c..5d2eb44f8a8 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -1,8 +1,7 @@ +#include #include #include #include -#include -#include #include #include #include @@ -15,369 +14,52 @@ namespace fuser { namespace cuda { namespace kir { -void Node::print() const { - std::cout << "\n"; - IrPrinter(std::cout).printNode(this); - std::cout << "\n"; -} - -Val::Val(Passkey passkey, DataType dtype) : Node(passkey), dtype_(dtype) { - // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48534 - id_ = passkey.kernel->newValueId(passkey); -} - -namespace { - -// Traverse definition of all values involved in constructing the provided val. -// Check if all values involved are constant values, meaning the provided -// val is also a constant value. -class ConstCheck : IrVisitor { - private: - bool is_const_ = true; - - using IrVisitor::visit; - - void visit(const Bool* b) override { - is_const_ = is_const_ && b->isConst(); - } - - void visit(const Double* d) override { - is_const_ = is_const_ && d->isConst(); - } - - void visit(const Int* i) override { - is_const_ = is_const_ && i->isConst(); - } - - void visit(const NamedScalar* ns) override { - is_const_ = is_const_ && false; - } - - void visit(const Expr* expr) { - for (auto inp : expr->inputs()) { - visit(inp); - } - } - - void visit(const Val* val) { - if (val->definition() != nullptr) { - visit(val->definition()); - } else { - val->accept(this); - } - } - - public: - static bool isConst(const Val* val) { - ConstCheck cc; - cc.visit(val); - return cc.is_const_; - } -}; - -} // namespace - -bool Val::isConstScalar() const { - if (!isScalar()) - return false; - return ConstCheck::isConst(this); -} - -Expr* Expr::parentScope() const { - if (scope()) { - return scope()->owner(); - } else { - return nullptr; - } -} - -NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) { - std::string parallel_dim = stringifyThreadSize(p_type); - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - return ir_builder.create(parallel_dim, DataType::Int); -} - -NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) { - std::string parallel_ind = stringifyThread(p_type); - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - return ir_builder.create(parallel_ind, DataType::Int); -} - -c10::optional NamedScalar::getParallelDim() const { - if (stringifyThreadSize(ParallelType::TIDx).compare(name()) == 0) { - return c10::optional(ParallelType::TIDx); - } else if (stringifyThreadSize(ParallelType::TIDy).compare(name()) == 0) { - return c10::optional(ParallelType::TIDy); - } else if (stringifyThreadSize(ParallelType::TIDz).compare(name()) == 0) { - return c10::optional(ParallelType::TIDz); - } else if (stringifyThreadSize(ParallelType::BIDx).compare(name()) == 0) { - return c10::optional(ParallelType::BIDx); - } else if (stringifyThreadSize(ParallelType::BIDy).compare(name()) == 0) { - return c10::optional(ParallelType::BIDy); - } else if (stringifyThreadSize(ParallelType::BIDz).compare(name()) == 0) { - return c10::optional(ParallelType::BIDz); - } - return c10::nullopt; -} - -c10::optional NamedScalar::getParallelIndex() const { - if (stringifyThread(ParallelType::TIDx).compare(name()) == 0) { - return c10::optional(ParallelType::TIDx); - } else if (stringifyThread(ParallelType::TIDy).compare(name()) == 0) { - return c10::optional(ParallelType::TIDy); - } else if (stringifyThread(ParallelType::TIDz).compare(name()) == 0) { - return c10::optional(ParallelType::TIDz); - } else if (stringifyThread(ParallelType::BIDx).compare(name()) == 0) { - return c10::optional(ParallelType::BIDx); - } else if (stringifyThread(ParallelType::BIDy).compare(name()) == 0) { - return c10::optional(ParallelType::BIDy); - } else if (stringifyThread(ParallelType::BIDz).compare(name()) == 0) { - return c10::optional(ParallelType::BIDz); - } - return c10::nullopt; -} - -IterDomain::IterDomain(Passkey passkey, Val* start, Val* extent) - : Val(passkey, DataType::Int), - start_(start), - stop_(extent), - extent_(extent) {} - -IterDomain::IterDomain( - Passkey passkey, - const fuser::cuda::IterDomain* iter_domain) - : Val(passkey, iter_domain->getDataType().value()), - start_(GpuLower::current()->lowerValue(iter_domain->start())), - stop_(GpuLower::current()->lowerValue(iter_domain->stop())), - extent_(GpuLower::current()->lowerValue(iter_domain->extent())), - parallel_type_(iter_domain->getParallelType()), - iter_type_(iter_domain->getIterType()), - is_rfactor_domain_(iter_domain->isRFactorProduct()), - is_simple_(iter_domain->definition() == nullptr), - is_padded_dimension_(iter_domain->hasPaddingToMultipleOfWarp()) { - // preserve the fusion node's name - setName(iter_domain->name()); -} - -//! Note that the parallel dimension, if available, may be different -//! from the actual extent of this IterDomain as the parallel -//! dimension is determined by the largest extent of IterDomains -//! sharing the same loop. -Val* IterDomain::extent() const { - TORCH_INTERNAL_ASSERT(extent_ != nullptr); - return extent_; -} - -TensorDomain::TensorDomain(Passkey passkey, std::vector domain) - : Val(passkey, DataType::Null), root_domain_(std::move(domain)) { - domain_ = root_domain_; - resetDomains(); -} - -TensorDomain::TensorDomain( - Passkey passkey, - const fuser::cuda::TensorDomain* tensor_domain) - : Val(passkey, DataType::Null), contiguity_(tensor_domain->contiguity()) { - // preserve the fusion node's name - setName(tensor_domain->name()); - - const auto lowerIterDomains = - [](const std::vector& domains) { - std::vector lowered_domains; - lowered_domains.reserve(domains.size()); - for (const auto iter_domain : domains) { - lowered_domains.push_back( - GpuLower::current()->lowerValue(iter_domain)->as()); - } - return lowered_domains; - }; - - root_domain_ = lowerIterDomains(tensor_domain->getRootDomain()); - domain_ = lowerIterDomains(tensor_domain->domain()); - no_bcast_domain_ = lowerIterDomains(tensor_domain->noBroadcasts()); - no_reduction_domain_ = lowerIterDomains(tensor_domain->noReductions()); - rfactor_domain_ = lowerIterDomains(tensor_domain->getRFactorDomain()); -} - -bool TensorDomain::hasReduction() const { - return no_reduction_domain_.size() != domain_.size(); -} - -bool TensorDomain::hasBlockReduction() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isReduction() && id->isThreadDim(); - }); -} - -bool TensorDomain::hasGridReduction() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isReduction() && id->isBlockDim(); - }); -} - -bool TensorDomain::hasBlockBroadcast() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isBroadcast() && id->isThreadDim(); - }); -} - -bool TensorDomain::hasGridBroadcast() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isBroadcast() && id->isBlockDim(); - }); -} - -bool TensorDomain::hasBroadcast() const { - return no_bcast_domain_.size() != domain_.size(); -} - -bool TensorDomain::hasRFactor() const { - return !rfactor_domain_.empty(); -} - -bool TensorDomain::hasVectorize() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->parallelType() == ParallelType::Vectorize || - id->parallelType() == ParallelType::MisalignedVectorize; - }); -} - -IterDomain* TensorDomain::axis(int i) const { - TORCH_INTERNAL_ASSERT(i >= 0 && i < int(domain_.size())); - return domain_[i]; -} - -std::vector TensorDomain::noReductions( - const std::vector& td) { - std::vector no_reduction_domains; - for (auto id : td) { - if (!id->isReduction()) { - no_reduction_domains.push_back(id); - } - } - return no_reduction_domains; -} - -std::vector TensorDomain::noBroadcasts( - const std::vector& td) { - std::vector no_broadcast_domains; - for (auto id : td) { - if (!id->isBroadcast()) { - no_broadcast_domains.push_back(id); - } - } - return no_broadcast_domains; -} - -TensorView::TensorView(Passkey passkey, const fuser::cuda::TensorView* tv) - : Val(passkey, tv->getDataType().value()), fuser_tv_(tv) { - setName(tv->name()); - domain_ = GpuLower::current()->lowerValue(tv->domain())->as(); - memory_type_ = tv->getMemoryType(); -} - -TensorView::TensorView( - Passkey passkey, - DataType dtype, - TensorDomain* domain, - MemoryType memory_type) - : Val(passkey, dtype), domain_(domain), memory_type_(memory_type) {} - -UnaryOp::UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in) - : Expr(passkey), operation_(operation), out_(out), in_(in) { - addOutput(out); - addInput(in); -} - -BinaryOp::BinaryOp( - Passkey passkey, - BinaryOpType operation, - Val* out, - Val* lhs, - Val* rhs) - : Expr(passkey), operation_(operation), out_(out), lhs_(lhs), rhs_(rhs) { - addOutput(out); - addInput(lhs); - addInput(rhs); -} - -TernaryOp::TernaryOp( - Passkey passkey, - TernaryOpType operation, - Val* out, - Val* in1, - Val* in2, - Val* in3) - : Expr(passkey), - operation_(operation), - out_(out), - in1_(in1), - in2_(in2), - in3_(in3) { - addOutput(out); - addInput(in1); - addInput(in2); - addInput(in3); -} - -ReductionOp::ReductionOp( - Passkey passkey, - BinaryOpType operation, - Val* init, - Val* out, - Val* in) - : Expr(passkey), operation_(operation), init_(init), out_(out), in_(in) { - addOutput(out); - addInput(in); +Predicate::Predicate( + IrBuilderPasskey passkey, + PredicateType ptype, + const Expr* expr, + Bool* thread_pred) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(ptype), + expr_(expr), + thread_pred_(thread_pred) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT( + ptype != PredicateType::Unswitch && ptype != PredicateType::Manual); } -WelfordOp::WelfordOp( - Passkey passkey, - Val* out_var, - Val* out_avg, - Val* out_N, - Val* init_var, - Val* init_avg, - Val* init_N, - Val* in_var, - Val* in_avg, - Val* in_N) - : Expr(passkey), - out_var_(out_var), - out_avg_(out_avg), - out_N_(out_N), - init_var_(init_var), - init_avg_(init_avg), - init_N_(init_N), - in_var_(in_var), - in_avg_(in_avg), - in_N_(in_N) { - addOutput(out_avg); - addOutput(out_var); - addOutput(out_N); - - if (!in_N->isOneInt()) { - addInput(in_var); - } - addInput(in_avg); - addInput(in_N); +Predicate::Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(PredicateType::Unswitch), + unrolled_loop_(unrolled_loop) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr); } -BroadcastOp::BroadcastOp(Passkey passkey, Val* out, Val* in) - : Expr(passkey), out_(out), in_(in) { - TORCH_CHECK(in->isA() || in->isA()); - TORCH_CHECK(out->isA() || out->isA()); - addOutput(out); - addInput(in); +Predicate::Predicate(IrBuilderPasskey passkey, Bool* value) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(PredicateType::Manual), + value_(value) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT(value != nullptr); } TensorIndex::TensorIndex( - Passkey passkey, - const fuser::cuda::TensorView* view, + IrBuilderPasskey passkey, + const TensorView* view, std::vector indices) - : Val(passkey, view->getDataType().value()), - view_(GpuLower::current()->lowerValue(view)->as()), + : Val(passkey, ValType::TensorIndex, view->getDataType().value()), + view_(view), indices_(indices) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); TORCH_INTERNAL_ASSERT( std::all_of( indices.begin(), @@ -392,20 +74,33 @@ TensorIndex::TensorIndex( indices_.end()); // If indices becomes empty, just put one ZeroInt if (indices_.empty()) { - indices_.push_back(kir::IrBuilder(GpuLower::current()->kernel()).zeroVal()); + indices_.push_back(FusionGuard::getCurFusion()->zeroVal()); } } -Sync::Sync(Passkey passkey, bool war_sync) - : Expr(passkey), war_sync_(war_sync) {} +Sync::Sync(IrBuilderPasskey passkey, bool war_sync) + : Expr(passkey, ExprType::Sync), war_sync_(war_sync) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} -InitMagicZero::InitMagicZero(Passkey passkey) : Expr(passkey) {} +InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) + : Expr(passkey, ExprType::InitMagicZero) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} -UpdateMagicZero::UpdateMagicZero(Passkey passkey) : Expr(passkey) {} +UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) + : Expr(passkey, ExprType::UpdateMagicZero) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); - expr->setScope(this); } void Scope::insert_before(Expr* ref, Expr* expr) { @@ -440,11 +135,6 @@ void Scope::insert(size_t pos, Expr* expr) { void Scope::erase(std::vector::const_iterator pos) { // Remove the scope of the expr if this is the scope auto expr = *pos; - TORCH_INTERNAL_ASSERT( - expr->scope() == this, - "Inconsistent scoping of expression detected: ", - kir::toString(expr)); - expr->setScope(nullptr); exprs_.erase(pos); } @@ -470,7 +160,7 @@ void Scope::clear() { } ForLoop::ForLoop( - Passkey passkey, + IrBuilderPasskey passkey, IterDomain* iter_domain, Val* index, Val* start, @@ -479,7 +169,7 @@ ForLoop::ForLoop( bool vectorize, Val* vectorize_shift, bool unroll_required) - : Expr(passkey), + : Expr(passkey, ExprType::ForLoop), iter_domain_{iter_domain}, index_(index), start_(start), @@ -489,43 +179,42 @@ ForLoop::ForLoop( vectorize_shift_(vectorize_shift), unroll_required_(unroll_required), body_(this) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int); addInput(index); addInput(iter_domain); if (start_ == nullptr && iter_domain->isThread()) { - start_ = - IrBuilder(GpuLower::current()->kernel()) - .create( - stringifyThread(iter_domain->parallelType()), DataType::Int); + start_ = NamedScalar::getParallelIndex(iter_domain->getParallelType()); } if (step_ == nullptr) { if (iter_domain->isThread()) { - step_ = IrBuilder(GpuLower::current()->kernel()) - .create( - stringifyThreadSize(iter_domain->parallelType()), - DataType::Int); + step_ = NamedScalar::getParallelDim(iter_domain->getParallelType()); } else { - step_ = IrBuilder(GpuLower::current()->kernel()).oneVal(); + step_ = FusionGuard::getCurFusion()->oneVal(); } } } -ForLoop::ForLoop(Passkey passkey, IterDomain* iter_domain) +ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain) : ForLoop( passkey, iter_domain, - iter_domain->isBroadcast() - ? IrBuilder(GpuLower::current()->kernel()).zeroVal() - : IrBuilder(GpuLower::current()->kernel()) - .create(c10::nullopt), + iter_domain->isBroadcast() ? FusionGuard::getCurFusion()->zeroVal() + : IrBuilder::create(c10::nullopt), nullptr, nullptr, nullptr, - isParallelTypeVectorize(iter_domain->parallelType()), + isParallelTypeVectorize(iter_domain->getParallelType()), nullptr, - false) {} + false) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} -ForLoop::ForLoop(Passkey passkey, const ForLoop* other) +ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) : ForLoop( passkey, other->iter_domain(), @@ -535,7 +224,11 @@ ForLoop::ForLoop(Passkey passkey, const ForLoop* other) other->step(), other->vectorize(), other->vectorize_shift(), - other->isUnrollRequired()) {} + other->isUnrollRequired()) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} bool ForLoop::isUnrollable() const { // Start and stop must be constant, must not be a broadcast @@ -550,7 +243,7 @@ bool ForLoop::isUnrolled() const { if (isUnrollRequired() && !isUnrollable()) { TORCH_WARN( "Unroll required but not possible. Register allocation disabled. Loop index: ", - kir::toString(index_)); + index_->toString()); return false; } @@ -570,7 +263,7 @@ bool ForLoop::isUnrolled() const { } // Unrolling is technically possible but avoided - if (iter_domain()->parallelType() == ParallelType::Unswitch) { + if (iter_domain()->getParallelType() == ParallelType::Unswitch) { // Use ParallelType::Unroll if unrolling is desired. Note that // unswitched size-one loops are not unrolled as they are not // materialized as actual for-loops. @@ -605,8 +298,8 @@ Val* ForLoop::step() const { return step_; } -IfThenElse::IfThenElse(Passkey passkey, Predicate* cond) - : Expr(passkey), then_body_(this), else_body_(this) { +IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond) + : Expr(passkey, ExprType::IfThenElse), then_body_(this), else_body_(this) { setPredicate(cond); addInput(cond); } @@ -621,17 +314,19 @@ Val* TensorIndex::index(int i) const { } Allocate::Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, std::vector shape, bool zero_init) - : Expr(passkey), + : Expr(passkey, ExprType::Allocate), buffer_(buffer), memory_type_(memory_type), shape_(std::move(shape)), zero_init_(zero_init) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); if (!shape_.empty()) { TORCH_INTERNAL_ASSERT( (shape_.size() == 1 && shape_[0]->isOneInt()) || @@ -639,7 +334,7 @@ Allocate::Allocate( } else { TORCH_INTERNAL_ASSERT(buffer_->isA()); TORCH_INTERNAL_ASSERT( - buffer_->as()->memoryType() == memory_type_); + buffer_->as()->getMemoryType() == memory_type_); const auto domain = buffer_->as()->domain(); for (auto axis : domain->noReductions()) { shape_.push_back(axis->extent()); @@ -650,19 +345,19 @@ Allocate::Allocate( if (size_ == nullptr) { size_ = s; } else { - size_ = ir_builder.mulExpr(size_, s); + size_ = IrBuilder::mulExpr(size_, s); } } if (size_ == nullptr) { - size_ = ir_builder.oneVal(); + size_ = FusionGuard::getCurFusion()->oneVal(); } addInput(size_); } Allocate::Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, Val* size, @@ -672,31 +367,57 @@ Allocate::Allocate( buffer, memory_type, size == nullptr ? std::vector{} : std::vector{size}, - zero_init) {} + zero_init) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} GridReduction::GridReduction( - Passkey passkey, + IrBuilderPasskey passkey, ReductionOp* reduction_op, Allocate* reduction_buffer, Allocate* sync_buffer) - : Expr(passkey), + : Expr(passkey, ExprType::GridReduction), reduction_op_(reduction_op), reduction_buffer_(reduction_buffer), - sync_buffer_(sync_buffer) {} + sync_buffer_(sync_buffer) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +GridBroadcast::GridBroadcast( + IrBuilderPasskey passkey, + BroadcastOp* broadcast_op, + Allocate* broadcast_buffer, + Allocate* sync_buffer) + : Expr(passkey, ExprType::GridBroadcast), + broadcast_op_(broadcast_op), + broadcast_buffer_(broadcast_buffer), + sync_buffer_(sync_buffer) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} GridWelford::GridWelford( - Passkey passkey, + IrBuilderPasskey passkey, WelfordOp* welford_op, Allocate* var_buffer, Allocate* avg_buffer, Allocate* n_buffer, Allocate* sync_buffer) - : Expr(passkey), + : Expr(passkey, ExprType::GridWelford), welford_op_(welford_op), var_buffer_(var_buffer), avg_buffer_(avg_buffer), n_buffer_(n_buffer), - sync_buffer_(sync_buffer) {} + sync_buffer_(sync_buffer) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} } // namespace kir } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index c1ac6052783..ad6be90bf98 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -1,16 +1,13 @@ #pragma once -#include -#include - -// TODO(kir): remove these once the Kernel IR is separated from Fusion IR +#include #include -#include -#include #include +#include +#include +#include #include -#include #include #include @@ -21,26 +18,22 @@ namespace torch { namespace jit { namespace fuser { namespace cuda { -namespace kir { -class IrBuilder; -class Kernel; +class IrBuilderPasskey; // Abstract nodes -class Node; class Val; class Expr; // Values -class NamedScalar; -class Predicate; class Bool; class Double; class Int; +class NamedScalar; + class IterDomain; class TensorDomain; class TensorView; -class TensorIndex; // Expressions class UnaryOp; @@ -50,7 +43,14 @@ class ReductionOp; class WelfordOp; class BroadcastOp; -// Statements +namespace kir { +class Kernel; + +// Values +class Predicate; +class TensorIndex; + +// Expressions class Allocate; class Sync; class InitMagicZero; @@ -64,443 +64,17 @@ class GridWelford; // Expr container class Scope; -using ValueId = int32_t; - -//! Token used to restrict the access to Kernel IR creation -//! -//! A token is associated with a kernel, which is passed with the key -//! (Passkey::kernel) -//! -//! It is a "granular friendship" token, used to implement the "passkey" idiom: -//! https://www.spiria.com/en/blog/desktop-software/passkey-idiom-and-better-friendship-c -//! https://arne-mertz.de/2016/10/passkey-idiom -//! -class Passkey { - friend class IrBuilder; - - public: - Kernel* const kernel = nullptr; - - private: - explicit Passkey(Kernel* kernel) : kernel(kernel) {} -}; - -//! Kernel IR visitor interface -class TORCH_CUDA_CU_API IrVisitor : public PolymorphicBase { - public: - // TODO(kir): use Node* instead of void* - virtual void unhandled(const void* node) {} - - // Values - virtual void visit(const NamedScalar* named_scalar) { - unhandled(named_scalar); - } - virtual void visit(const Predicate* value) { - unhandled(value); - } - virtual void visit(const Bool* value) { - unhandled(value); - } - virtual void visit(const Double* value) { - unhandled(value); - } - virtual void visit(const Int* value) { - unhandled(value); - } - virtual void visit(const IterDomain* iter_domain) { - unhandled(iter_domain); - } - virtual void visit(const TensorDomain* tensor_domain) { - unhandled(tensor_domain); - } - virtual void visit(const TensorView* tensor_view) { - unhandled(tensor_view); - } - virtual void visit(const TensorIndex* tensor_index) { - unhandled(tensor_index); - } - - // Expressions - virtual void visit(const UnaryOp* node) { - unhandled(node); - } - virtual void visit(const BinaryOp* node) { - unhandled(node); - } - virtual void visit(const TernaryOp* node) { - unhandled(node); - } - virtual void visit(const ReductionOp* node) { - unhandled(node); - } - virtual void visit(const WelfordOp* node) { - unhandled(node); - } - virtual void visit(const BroadcastOp* node) { - unhandled(node); - } - - // Statements - virtual void visit(const Allocate* node) { - unhandled(node); - } - virtual void visit(const Sync* node) { - unhandled(node); - } - virtual void visit(const InitMagicZero* node) { - unhandled(node); - } - virtual void visit(const UpdateMagicZero* node) { - unhandled(node); - } - virtual void visit(const ForLoop* node) { - unhandled(node); - } - virtual void visit(const IfThenElse* node) { - unhandled(node); - } - virtual void visit(const GridReduction* node) { - unhandled(node); - } - virtual void visit(const GridBroadcast* node) { - unhandled(node); - } - virtual void visit(const GridWelford* node) { - unhandled(node); - } -}; - -//! Kernel IR visitor interface -class TORCH_CUDA_CU_API MutableIrVisitor : public PolymorphicBase { - public: - // TODO(kir): use Node* instead of void* - virtual void unhandled(const void*) {} - - // Values - virtual void visit(NamedScalar* named_scalar) { - unhandled(named_scalar); - } - virtual void visit(Predicate* value) { - unhandled(value); - } - virtual void visit(Bool* value) { - unhandled(value); - } - virtual void visit(Double* value) { - unhandled(value); - } - virtual void visit(Int* value) { - unhandled(value); - } - virtual void visit(IterDomain* iter_domain) { - unhandled(iter_domain); - } - virtual void visit(TensorDomain* tensor_domain) { - unhandled(tensor_domain); - } - virtual void visit(TensorView* tensor_view) { - unhandled(tensor_view); - } - virtual void visit(TensorIndex* tensor_index) { - unhandled(tensor_index); - } - - // Expressions - virtual void visit(UnaryOp* node) { - unhandled(node); - } - virtual void visit(BinaryOp* node) { - unhandled(node); - } - virtual void visit(TernaryOp* node) { - unhandled(node); - } - virtual void visit(ReductionOp* node) { - unhandled(node); - } - virtual void visit(BroadcastOp* node) { - unhandled(node); - } - - virtual void visit(WelfordOp* node) { - unhandled(node); - } - - // Statements - virtual void visit(Allocate* node) { - unhandled(node); - } - virtual void visit(Sync* node) { - unhandled(node); - } - virtual void visit(InitMagicZero* node) { - unhandled(node); - } - virtual void visit(UpdateMagicZero* node) { - unhandled(node); - } - virtual void visit(ForLoop* node) { - unhandled(node); - } - virtual void visit(IfThenElse* node) { - unhandled(node); - } - virtual void visit(GridReduction* node) { - unhandled(node); - } - virtual void visit(GridBroadcast* node) { - unhandled(node); - } - virtual void visit(GridWelford* node) { - unhandled(node); - } -}; - -//! Base class for Kernel IR nodes -class TORCH_CUDA_CU_API Node : public NonCopyable, public PolymorphicBase { - public: - explicit Node(Passkey) {} - - //! IR Visitor double-dispatch interface - //! (https://en.wikipedia.org/wiki/Visitor_pattern) - virtual void accept(IrVisitor* visitor) const = 0; - - //! Non constant IR Visitor - virtual void accept(MutableIrVisitor* visitor) = 0; - - //! Debug helper, prints the textual representation of an IR node - void print() const; -}; - -//! Generic value (scalar or tensor) -class TORCH_CUDA_CU_API Val : public Node { - public: - Val(Passkey passkey, DataType dtype); - - // TODO(kir): consider renaming - StmtNameType name() const { - return name_; - } - - void setName(StmtNameType name) { - name_ = name; - } - - ValueId id() const { - return id_; - } - - DataType dtype() const { - return dtype_; - } - - Expr* definition() const { - return definition_; - } - - void setDefinition(Expr* expr) { - // TODO(kir): extra checks on changing existing definitions? - definition_ = expr; - } - - virtual bool isScalar() const { - return false; - } - - bool isConstScalar() const; - - virtual bool isConst() const { - return false; - } - - // TODO(kir): revisit and find a better interface - virtual bool isZeroInt() const { - return false; - } - - virtual bool isOneInt() const { - return false; - } - - void setEvaluatorIndex(int to) { - TORCH_INTERNAL_ASSERT(evaluator_index_ == -1); - evaluator_index_ = to; - } - - int evaluatorIndex() const { - return evaluator_index_; - } - - private: - const DataType dtype_; - - // The expression which defines this value, or nullptr - Expr* definition_ = nullptr; - - // This is a value name preserved from the Fusion IR (optional) - StmtNameType name_ = kInvalidStmName; - - // All Kernel IR values have IDs (unique within the same Kernel) - ValueId id_ = -1; - - // Expr evaluator idx; - int evaluator_index_ = -1; -}; - -//! Base class for expressions and statements -//! -//! Expressions consume inputs and produce outputs (depending on the context -//! this may imply assignments). Currently some of the expressions -//! don't actually produce any outputs (ForLoop, IfThenElse) and they -//! model statements to be executed. -//! -//! TODO(kir): split the expressions, assignments and statements? -//! -class TORCH_CUDA_CU_API Expr : public Node { - public: - explicit Expr(Passkey passkey) : Node(passkey) {} - - const auto& inputs() const { - return inputs_; - } - - const auto& outputs() const { - return outputs_; - } - - Scope* scope() const { - return scope_; - } - - //! Set the current scope - void setScope(Scope* scope) { - scope_ = scope; - } - - Expr* parentScope() const; - - Predicate* predicate() const { - return predicate_; - } - - void setPredicate(Predicate* predicate) { - predicate_ = predicate; - } - - Predicate* writePredicate() const { - return write_predicate_; - } - - void setWritePredicate(Predicate* write_predicate) { - write_predicate_ = write_predicate; - } - - protected: - // TODO(kir): try to avoid this protected interface - void addInput(Val* input) { - inputs_.push_back(input); - } - - void addOutput(Val* output) { - output->setDefinition(this); - outputs_.push_back(output); - } - - private: - // TODO(kir): can we avoid this? - std::vector inputs_; - std::vector outputs_; - - // TODO(kir): revisit scope/nesting data structures - Scope* scope_ = nullptr; - - Predicate* predicate_ = nullptr; - // Only used for reduction-related expressions - Predicate* write_predicate_ = nullptr; -}; - -class TORCH_CUDA_CU_API NamedScalar final : public Val { - public: - // NOLINTNEXTLINE(modernize-pass-by-value) - NamedScalar(Passkey passkey, std::string name, DataType dtype) - : Val(passkey, dtype), name_(name) {} - - explicit NamedScalar(Passkey passkey, const fuser::cuda::NamedScalar* node) - : Val(passkey, node->getDataType().value()) { - name_ = node->name(); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - // TODO(kir): this is hiding and redefining Val::name() - const std::string& name() const { - return name_; - } - - // Return the named scalar extent of a parallel dimension (e.g. blockDim.x) - static NamedScalar* getParallelDim(ParallelType p_type); - - // Return the named scalar index of a parallel dimension (e.g. threadIdx.x) - static NamedScalar* getParallelIndex(ParallelType p_type); - - // Return the parallel type of this NamedScalar if it is an extent of a - // parallel dimension - c10::optional getParallelDim() const; - - // Return the parallel type of this NamedScalar if it is an index of a - // parallel dimension - c10::optional getParallelIndex() const; - - private: - std::string name_; -}; - class TORCH_CUDA_CU_API Predicate final : public Val { public: explicit Predicate( - Passkey passkey, + IrBuilderPasskey passkey, PredicateType ptype, const Expr* expr = nullptr, - Bool* thread_pred = nullptr) - : Val(passkey, DataType::Bool), - ptype_(ptype), - expr_(expr), - thread_pred_(thread_pred) { - TORCH_INTERNAL_ASSERT( - ptype != PredicateType::Unswitch && ptype != PredicateType::Manual); - } + Bool* thread_pred = nullptr); - explicit Predicate(Passkey passkey, ForLoop* unrolled_loop) - : Val(passkey, DataType::Bool), - ptype_(PredicateType::Unswitch), - unrolled_loop_(unrolled_loop) { - TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr); - } - - explicit Predicate(Passkey passkey, Bool* value) - : Val(passkey, DataType::Bool), - ptype_(PredicateType::Manual), - value_(value) { - TORCH_INTERNAL_ASSERT(value != nullptr); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } + explicit Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop); - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit Predicate(IrBuilderPasskey passkey, Bool* value); PredicateType predicate_type() const { return ptype_; @@ -543,6 +117,10 @@ class TORCH_CUDA_CU_API Predicate final : public Val { value_ = value; } + bool isConst() const final { + return hasValue() && value_->isConst(); + } + private: PredicateType ptype_ = PredicateType::Manual; @@ -561,603 +139,13 @@ class TORCH_CUDA_CU_API Predicate final : public Val { Bool* value_ = nullptr; }; -class TORCH_CUDA_CU_API Bool final : public Val { - public: - explicit Bool(Passkey passkey, const c10::optional& value) - : Val(passkey, DataType::Bool), maybe_value_(value) {} - - explicit Bool(Passkey passkey, const fuser::cuda::Bool* node) - : Val(passkey, DataType::Bool), maybe_value_(node->value()) { - setName(node->name()); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - bool isConst() const override { - return maybe_value_.has_value(); - } - - c10::optional value() const { - return maybe_value_; - } - - private: - const c10::optional maybe_value_; -}; - -class TORCH_CUDA_CU_API Double final : public Val { - public: - using ScalarType = double; - - explicit Double(Passkey passkey, const c10::optional& value) - : Val(passkey, DataType::Double), maybe_value_(value) {} - - explicit Double(Passkey passkey, const fuser::cuda::Double* node) - : Val(passkey, DataType::Double), maybe_value_(node->value()) { - setName(node->name()); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - bool isConst() const override { - return maybe_value_.has_value(); - } - - c10::optional value() const { - return maybe_value_; - } - - private: - const c10::optional maybe_value_; -}; - -class TORCH_CUDA_CU_API Int final : public Val { - public: - using ScalarType = int64_t; - - explicit Int(Passkey passkey, const c10::optional& value) - : Val(passkey, DataType::Int), maybe_value_(value) {} - - // SFINAE constructor to avoid 0 constant pointer ambiguity - template < - typename T, - typename = typename std::enable_if< - std::is_pointer::value && - std::is_convertible::value>::type> - explicit Int(Passkey passkey, T node) - : Val(passkey, DataType::Int), maybe_value_(node->value()) { - setName(node->name()); - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isScalar() const override { - return true; - } - - bool isConst() const override { - return maybe_value_.has_value(); - } - - bool isZeroInt() const override { - return maybe_value_.has_value() && *maybe_value_ == 0; - } - - bool isOneInt() const override { - return maybe_value_.has_value() && *maybe_value_ == 1; - } - - c10::optional value() const { - return maybe_value_; - } - - private: - const c10::optional maybe_value_; -}; - -class TORCH_CUDA_CU_API IterDomain final : public Val { - public: - IterDomain(Passkey passkey, Val* start, Val* extent); - - explicit IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - bool isReduction() const { - return iterType() == IterType::Reduction; - } - - bool isRFactorProduct() const { - return is_rfactor_domain_; - } - - bool isBroadcast() const { - return iterType() == IterType::BroadcastWithStride || - iterType() == IterType::BroadcastWithoutStride; - } - - bool isGather() const { - return iterType() == IterType::Gather; - } - - bool isStride() const { - return iterType() == IterType::Stride; - } - - bool isParallelized() const { - return parallelType() != ParallelType::Serial; - } - - // Return if this iter domain is mapped to a grid dimension - bool isBlockDim() const { - return parallelType() == ParallelType::BIDz || - parallelType() == ParallelType::BIDy || - parallelType() == ParallelType::BIDx; - } - - // Return if this iter domain is mapped to a block dimension - bool isThreadDim() const { - return parallelType() == ParallelType::TIDz || - parallelType() == ParallelType::TIDy || - parallelType() == ParallelType::TIDx; - } - - // Return if this iter domain is either mapped to a block or grid dimension - bool isThread() const { - return isBlockDim() || isThreadDim(); - } - - ParallelType parallelType() const { - return parallel_type_; - } - - IterType iterType() const { - return iter_type_; - } - - Val* start() const { - return start_; - } - - Val* stop() const { - return stop_; - } - - Val* extent() const; - - bool isSimple() const { - return is_simple_; - } - - bool hasPaddingToMultipleOfWarp() const { - return is_padded_dimension_; - } - - private: - Val* const start_ = nullptr; - Val* const stop_ = nullptr; - Val* const extent_ = nullptr; - ParallelType parallel_type_ = ParallelType::Serial; - IterType iter_type_ = IterType::Iteration; - bool is_rfactor_domain_ = false; - - // An IterDomain is "simple" if the original Fusion IterDomain - // doesn't have a definition ("definition" expression) - // - // TODO(kir): this feels like a hack, revisit - // - bool is_simple_ = true; - - //! Indicates if this iterdomain is a padded parallel dimension - bool is_padded_dimension_ = false; -}; - -// TODO(kir): is this really a value? -class TORCH_CUDA_CU_API TensorDomain final : public Val { - public: - explicit TensorDomain(Passkey, std::vector domain); - - explicit TensorDomain( - Passkey passkey, - const fuser::cuda::TensorDomain* tensor_domain); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - std::vector::size_type nDims() const { - return domain_.size(); - } - - // TODO(kir): rename this - const std::vector& domain() const { - return domain_; - } - - const std::vector& contiguity() const { - return contiguity_; - } - - std::string getContiguityString() const { - std::stringstream ss; - for (auto b : contiguity()) { - ss << (b ? "t" : "f"); - } - return ss.str(); - } - - bool hasReduction() const; - bool hasBlockReduction() const; - bool hasGridReduction() const; - bool hasBlockBroadcast() const; - bool hasGridBroadcast() const; - bool hasBroadcast() const; - bool hasRFactor() const; - bool hasVectorize() const; - - const std::vector& noReductions() const { - return no_reduction_domain_; - } - - const std::vector& noBroadcasts() const { - return no_bcast_domain_; - } - - const std::vector& rootDomain() const { - return root_domain_; - }; - - const std::vector& rfactorDomain() const { - return rfactor_domain_; - }; - - void resetDomains() { - no_reduction_domain_ = noReductions(domain_); - no_bcast_domain_ = noBroadcasts(domain_); - } - - IterDomain* axis(int i) const; - - // TODO(kir): overloading non-static and static methods is not a good idea - static std::vector noReductions(const std::vector&); - static std::vector noBroadcasts(const std::vector&); - - private: - std::vector root_domain_; - std::vector domain_; - std::vector no_bcast_domain_; - std::vector no_reduction_domain_; - std::vector rfactor_domain_; - const std::vector contiguity_; -}; - -class TORCH_CUDA_CU_API TensorView final : public Val { - public: - explicit TensorView(Passkey, const fuser::cuda::TensorView* tv); - - TensorView( - Passkey, - DataType dtype, - TensorDomain* domain, - MemoryType memory_type); - - TensorDomain* domain() const { - return domain_; - } - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - MemoryType memoryType() const { - return memory_type_; - } - - fuser::cuda::TensorView* fuserTv() const { - TORCH_INTERNAL_ASSERT(fuser_tv_ != nullptr); - // TODO(kir): remove the need for const_cast - return const_cast(fuser_tv_); // NOLINT - } - - private: - TensorDomain* domain_ = nullptr; - MemoryType memory_type_ = MemoryType::Local; - - // TODO(kir): remove temporary hack - const fuser::cuda::TensorView* fuser_tv_ = nullptr; -}; - -class TORCH_CUDA_CU_API UnaryOp final : public Expr { - public: - UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in() const { - return in_; - } - - UnaryOpType operation() const { - return operation_; - } - - private: - const UnaryOpType operation_; - Val* const out_ = nullptr; - Val* const in_ = nullptr; -}; - -class TORCH_CUDA_CU_API BinaryOp final : public Expr { - public: - BinaryOp( - Passkey passkey, - BinaryOpType operation, - Val* out, - Val* lhs, - Val* rhs); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* lhs() const { - return lhs_; - } - - Val* rhs() const { - return rhs_; - } - - BinaryOpType operation() const { - return operation_; - } - - private: - const BinaryOpType operation_; - Val* const out_ = nullptr; - Val* const lhs_ = nullptr; - Val* const rhs_ = nullptr; -}; - -class TORCH_CUDA_CU_API TernaryOp final : public Expr { - public: - TernaryOp( - Passkey passkey, - TernaryOpType operation, - Val* out, - Val* in1, - Val* in2, - Val* in3); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in1() const { - return in1_; - } - - Val* in2() const { - return in2_; - } - - Val* in3() const { - return in3_; - } - - TernaryOpType operation() const { - return operation_; - } - - private: - const TernaryOpType operation_; - Val* const out_ = nullptr; - Val* const in1_ = nullptr; - Val* const in2_ = nullptr; - Val* const in3_ = nullptr; -}; - -class TORCH_CUDA_CU_API ReductionOp final : public Expr { - public: - ReductionOp( - Passkey passkey, - BinaryOpType operation, - Val* init, - Val* out, - Val* in); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in() const { - return in_; - } - - Val* init() const { - return init_; - } - - BinaryOpType operation() const { - return operation_; - } - - private: - const BinaryOpType operation_; - Val* const init_ = nullptr; - Val* const out_ = nullptr; - Val* const in_ = nullptr; -}; - -class TORCH_CUDA_CU_API WelfordOp final : public Expr { - public: - WelfordOp( - Passkey passkey, - Val* out_var, - Val* out_avg, - Val* out_N, - Val* init_var, - Val* init_avg, - Val* init_N, - Val* in_var, - Val* in_avg, - Val* in_N); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_avg_; - } - - Val* in() const { - return in_avg_; - } - - // Welford Specific accessors - // Almost wanted to add a new struct for {var, avg, N} - Val* outVar() const { - return out_var_; - } - - Val* outAvg() const { - return out_avg_; - } - - Val* outN() const { - return out_N_; - } - - Val* initVar() const { - return init_var_; - } - - Val* initAvg() const { - return init_avg_; - } - - Val* initN() const { - return init_N_; - } - - Val* inVar() const { - return in_var_; - } - - Val* inAvg() const { - return in_avg_; - } - - Val* inN() const { - return in_N_; - } - - private: - Val* const out_var_; - Val* const out_avg_; - Val* const out_N_; - Val* const init_var_; - Val* const init_avg_; - Val* const init_N_; - Val* const in_var_; - Val* const in_avg_; - Val* const in_N_; -}; - class TORCH_CUDA_CU_API TensorIndex final : public Val { public: TensorIndex( - Passkey, + IrBuilderPasskey, const fuser::cuda::TensorView* view, std::vector indices); - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - std::vector::size_type nDims() const { return indices_.size(); } @@ -1170,8 +158,7 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { TensorView* view() const { TORCH_INTERNAL_ASSERT(view_ != nullptr); - // TODO(kir): remove the need for const_cast - return const_cast(view_); // NOLINT + return const_cast(view_); // NOLINT } private: @@ -1179,46 +166,17 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { std::vector indices_; }; -class TORCH_CUDA_CU_API BroadcastOp final : public Expr { - public: - BroadcastOp(Passkey passkey, Val* out, Val* in); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - - Val* out() const { - return out_; - } - - Val* in() const { - return in_; - } - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; -}; - //! Allocate is a lower level Node that describes a buffer of memory that //! is required as an intermediate within a kernel. The extent is the expression //! of the size of the buffer that is generated from the TensorView that //! describes the output of an operation. -//! -//! TODO(kir): The components of Allocate like Type and Name could be separated -//! from the the assocated TensorView. Perhaps that is more appropriate? -//! class TORCH_CUDA_CU_API Allocate final : public Expr { public: //! Allocation of a multi-dimensional buffer //! //! param shape Size of each dimension explicit Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, std::vector shape = {}, @@ -1228,20 +186,12 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { //! //! param size Size of allocation explicit Allocate( - Passkey passkey, + IrBuilderPasskey passkey, Val* buffer, MemoryType memory_type, Val* size, bool zero_init = false); - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - Val* buffer() const { return buffer_; } @@ -1292,15 +242,7 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { // class TORCH_CUDA_CU_API Sync final : public Expr { public: - explicit Sync(Passkey passkey, bool war_sync = false); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit Sync(IrBuilderPasskey passkey, bool war_sync = false); bool isWarHazardSync() const { return war_sync_; @@ -1315,30 +257,14 @@ class TORCH_CUDA_CU_API Sync final : public Expr { // in helpers.cu class TORCH_CUDA_CU_API InitMagicZero final : public Expr { public: - explicit InitMagicZero(Passkey passkey); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit InitMagicZero(IrBuilderPasskey passkey); }; // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero // in helpers.cu class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { public: - explicit UpdateMagicZero(Passkey passkey); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit UpdateMagicZero(IrBuilderPasskey passkey); }; // TODO(kir): promote to IR node @@ -1377,7 +303,6 @@ class TORCH_CUDA_CU_API Scope { void push_back(Expr* e) { exprs_.push_back(e); - e->setScope(this); } // Erase expr at pos @@ -1425,7 +350,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! //! TODO: cleaner way to set options? ForLoop( - Passkey passkey, + IrBuilderPasskey passkey, IterDomain* iter_domain, Val* index, Val* start, @@ -1435,17 +360,9 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { Val* vectorize_shift, bool unroll_required); - ForLoop(Passkey passkey, IterDomain* iter_domain); - - ForLoop(Passkey passkey, const ForLoop* other); + ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain); - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + ForLoop(IrBuilderPasskey passkey, const ForLoop* other); Val* index() const { return index_; @@ -1465,6 +382,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { return iter_domain_; } + // TODO: Return pointer instead of reference to be more consistent Scope& body() { return body_; } @@ -1524,15 +442,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! class TORCH_CUDA_CU_API IfThenElse final : public Expr { public: - explicit IfThenElse(Passkey passkey, Predicate* cond); - - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } + explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); Scope& thenBody() { return then_body_; @@ -1567,16 +477,8 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr { //! reduction and sync buffers. class TORCH_CUDA_CU_API GridReduction final : public Expr { public: - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - GridReduction( - Passkey passkey, + IrBuilderPasskey passkey, ReductionOp* reduction_op, Allocate* reduction_buffer, Allocate* sync_buffer); @@ -1620,23 +522,11 @@ class TORCH_CUDA_CU_API GridReduction final : public Expr { //! broadcast and sync buffers. class TORCH_CUDA_CU_API GridBroadcast final : public Expr { public: - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - GridBroadcast( - Passkey passkey, + IrBuilderPasskey passkey, BroadcastOp* broadcast_op, Allocate* broadcast_buffer, - Allocate* sync_buffer) - : Expr(passkey), - broadcast_op_(broadcast_op), - broadcast_buffer_(broadcast_buffer), - sync_buffer_(sync_buffer){}; + Allocate* sync_buffer); BroadcastOp* broadcast_op() const { return broadcast_op_; @@ -1665,16 +555,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { //! reduction and sync buffers. class TORCH_CUDA_CU_API GridWelford final : public Expr { public: - void accept(IrVisitor* visitor) const override { - visitor->visit(this); - } - - void accept(MutableIrVisitor* visitor) override { - visitor->visit(this); - } - GridWelford( - Passkey passkey, + IrBuilderPasskey passkey, WelfordOp* welford_op, Allocate* var_buffer, Allocate* avg_buffer, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h b/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h deleted file mode 100644 index 17a095baf12..00000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_builder.h +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { -namespace kir { - -//! Kernel IR builder interface -//! -//! The only way to create new Kernel IR nodes is through the -//! kir::IrBuilder interface. An IrBuilder instance is attached to a -//! particular Kernel instance and it provides methods for creating -//! single nodes (kir::IrBuilder::create()) or basic composite expressions -//! (ex. kir::IrBuilder::addExpr()). -//! -//! If the Kernel object is readily available, an IrBuilder can be "wrapped" -//! around it directly: -//! -//! kir::IrBuilder ir_builder(kernel); -//! -//! During lowering, another option is to create an IrBuilder for the -//! kernel that is being created: -//! -//! kir::IrBuilder ir_builder(GpuLower::current()->kernel()); -//! -//! Once we have an IR builder instance, creating nodes looks like: -//! -//! auto new_node = ir_builder.create(1)); -//! auto result = ir_builder.mulExpr(lhs, rhs); -//! -class TORCH_CUDA_CU_API IrBuilder { - public: - explicit IrBuilder(Kernel* kernel) : kernel_(kernel) {} - - //! Allocate a new Kernel IR node, forwarding the arguments - //! to the appropriate constructor - template - T* create(Args&&... args) { - const kir::Passkey passkey(kernel_); - const auto node = new T(passkey, std::forward(args)...); - kernel_->registerIrNode(passkey, std::unique_ptr(node)); - return node; - } - - // Unary operations - Val* negExpr(Val* val); - Val* notExpr(Val* val); - Val* setExpr(Val* val); - Val* setExprNamedScalar(const std::string& name, Val* val); - Val* addressExprNamedScalar(const std::string& name, Val* val); - - // Binary operations - Val* andExpr(Val* lhs, Val* rhs); - Val* eqExpr(Val* lhs, Val* rhs); - Val* gtExpr(Val* lhs, Val* rhs); - Val* ltExpr(Val* lhs, Val* rhs); - Val* leExpr(Val* lhs, Val* rhs); - Val* geExpr(Val* lhs, Val* rhs); - Val* addExpr(Val* lhs, Val* rhs); - Val* subExpr(Val* lhs, Val* rhs); - Val* mulExpr(Val* lhs, Val* rhs); - Val* divExpr(Val* lhs, Val* rhs); - Val* ceilDivExpr(Val* lhs, Val* rhs); - Val* modExpr(Val* lhs, Val* rhs); - Val* maxExpr(Val* lhs, Val* rhs); - Val* minExpr(Val* lhs, Val* rhs); - - // Ternary operations - Val* whereExpr(Val* pred, Val* lhs, Val* rhs); - - // Shortcuts for frequently used vals - Int* zeroVal(); - Int* oneVal(); - Bool* falseVal(); - Bool* trueVal(); - - NamedScalar* magicZeroVal(); - - private: - Val* newResult(DataType dtype); - Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs); - Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs); - - private: - // Non-owning pointer to the kernel to be modified - Kernel* kernel_ = nullptr; - // Frequently used constant vals - Int* zero_ = nullptr; - Int* one_ = nullptr; - Bool* false_ = nullptr; - Bool* true_ = nullptr; - - // Magic zero corresponds to runtime/helpers.cu magic_zero - NamedScalar* magic_zero_ = nullptr; -}; - -//! A wrapper builder with static expression simplification -//! -//! Example: -//! - addExpr(new Int(1), new Int(2)) -> Int(3) -//! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo") -//! -//! Designed to be used to simplify predicate and index expressions in -//! generated code. Also, the shift validation may fail without -//! this simplification. -class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { - public: - explicit SimplifyingIrBuilder(Kernel* kernel) : IrBuilder(kernel) {} - - Val* negExpr(Val* val); - Val* notExpr(Val* val); - - Val* addExpr(Int* lhs, Int::ScalarType rhs); - Val* addExpr(Int* lhs, Int* rhs); - Val* addExpr(Val* lhs, Val* rhs); - Val* subExpr(Val* lhs, Val* rhs); - Val* andExpr(Val* lhs, Val* rhs); -}; - -} // namespace kir -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp new file mode 100644 index 00000000000..bfc4794e299 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp @@ -0,0 +1,180 @@ +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace kir { +std::vector IrVisitor::handle(const std::vector& exprs) { + exprs_ = std::vector(exprs); + for (auto expr : exprs) { + handle(expr); + } + return exprs_; +} + +void IrVisitor::handle(ForLoop* fl) { + for_loops_.push_back(fl); + scope_.push_back(&fl->body()); + auto body_exprs = std::vector(fl->body().exprs()); + for (auto expr : body_exprs) { + handle(expr); + } + scope_.pop_back(); + for_loops_.pop_back(); +} + +void IrVisitor::handle(IfThenElse* ite) { + scope_.push_back(&ite->thenBody()); + auto then_exprs = std::vector(ite->thenBody().exprs()); + for (auto expr : then_exprs) { + handle(expr); + } + scope_.pop_back(); + + scope_.push_back(&ite->elseBody()); + auto else_exprs = std::vector(ite->elseBody().exprs()); + for (auto expr : else_exprs) { + handle(expr); + } + scope_.pop_back(); +} + +std::vector ExprMutator::mutate(bool reverse_order) { + if (insertions_.empty() && replacements_.empty()) { + return exprs_; + } + + auto run_insertion = [&](MutationInformation info) { + if (info.scope == nullptr) { + // If reference is nullptr and there are no expressions, simply insert the + // expr + if (exprs_.empty() && info.reference == nullptr) { + exprs_.push_back(info.new_expr); + return; + } + auto pos_it = std::find(exprs_.begin(), exprs_.end(), info.reference); + TORCH_INTERNAL_ASSERT( + pos_it != exprs_.end(), + "Issue finding reference expression for insertion."); + if (info.mode == MutationMode::BEFORE) { + exprs_.insert(pos_it, info.new_expr); + } else { + exprs_.insert(pos_it + 1, info.new_expr); + } + } else { + // If reference is nullptr and there are no expressions, simply insert the + // expr + if (info.scope->exprs().empty() && info.reference == nullptr) { + info.scope->push_back(info.new_expr); + return; + } + if (info.mode == MutationMode::BEFORE) { + info.scope->insert_before(info.reference, info.new_expr); + } else { + info.scope->insert_after(info.reference, info.new_expr); + } + } + }; + + if (reverse_order) { + for (auto it = insertions_.rbegin(); it != insertions_.rend(); ++it) { + run_insertion(*it); + } + } else { + for (auto insertion_info : insertions_) { + run_insertion(insertion_info); + } + } + + for (auto replacement_info : replacements_) { + if (replacement_info.scope == nullptr) { + auto pos_it = + std::find(exprs_.begin(), exprs_.end(), replacement_info.reference); + TORCH_INTERNAL_ASSERT( + pos_it != exprs_.end(), + "Issue finding reference expression for replacement."); + exprs_.insert(pos_it, replacement_info.new_expr); + // iterator can be invalidated from insertion + pos_it = + std::find(exprs_.begin(), exprs_.end(), replacement_info.reference); + exprs_.erase(pos_it); + } else { + replacement_info.scope->insert_before( + replacement_info.reference, replacement_info.new_expr); + replacement_info.scope->erase(replacement_info.reference); + } + } + + insertions_.clear(); + replacements_.clear(); + + return exprs_; +} + +std::vector ExprMutator::traverseAndInsert( + const std::vector& exprs, + bool reverse_order) { + IrVisitor::handle(exprs); + return mutate(reverse_order); +} + +void ExprMutator::registerMutation( + Expr* reference, + Expr* new_expr, + Scope* scope, + MutationMode mode) { + MutationInformation mutation; + mutation.reference = reference; + mutation.new_expr = new_expr; + mutation.scope = scope; + mutation.mode = mode; + if (mode == MutationMode::BEFORE || mode == MutationMode::AFTER) { + insertions_.push_back(mutation); + } else { + replacements_.push_back(mutation); + } +} + +void ExprMutator::registerInsertBefore( + Expr* reference, + Expr* new_expr, + Scope* scope) { + registerMutation(reference, new_expr, scope, MutationMode::BEFORE); +} + +void ExprMutator::registerInsertAfter( + Expr* reference, + Expr* new_expr, + Scope* scope) { + registerMutation(reference, new_expr, scope, MutationMode::AFTER); +} + +void ExprMutator::registerReplace( + Expr* reference, + Expr* new_expr, + Scope* scope) { + registerMutation(reference, new_expr, scope, MutationMode::REPLACE); +} + +void ExprMutator::registerInsertBefore(Expr* reference, Expr* new_expr) { + Scope* scope = scope_.empty() ? nullptr : scope_.back(); + registerInsertBefore(reference, new_expr, scope); +} + +void ExprMutator::registerInsertAfter(Expr* reference, Expr* new_expr) { + Scope* scope = scope_.empty() ? nullptr : scope_.back(); + registerInsertAfter(reference, new_expr, scope); +} + +void ExprMutator::registerReplace(Expr* reference, Expr* new_expr) { + Scope* scope = scope_.empty() ? nullptr : scope_.back(); + registerReplace(reference, new_expr, scope); +} + +} // namespace kir +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h new file mode 100644 index 00000000000..2140498af14 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h @@ -0,0 +1,118 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class Expr; + +namespace kir { +class Predicate; +class TensorIndex; +class ForLoop; +class IfThenElse; +class Scope; + +// Base visitor class that visits all nodes in provided vector. +// +// Includes visiting through scopes like IfThenElse and ForLoop, and tracks +// them in scopes_ and for_loops_. +// +// Makes a copy of exprs at exprs_ which could be used to modify and return. +// +// When traversing through ITE/FLs it will use a copy +// of the provided expressions to make it safe to insert/delete nodes. +// +// Provides a simple base class to inherit from for typical lowering passes on +// Expr list +class TORCH_CUDA_CU_API IrVisitor : public OptOutDispatch { + public: + std::vector handle(const std::vector& expr); + + protected: + using OptOutDispatch::handle; + + virtual void handle(ForLoop*) override; + virtual void handle(IfThenElse*) override; + + protected: + std::vector for_loops_; + std::vector scope_; + std::vector exprs_; +}; + +// Base Expr Mutator class that visits all nodes with IrVisitor, and then +// inserts new expressions or replaces expressions based on insertion/replace +// maps provided. These replacement maps are expected to accumulate during an +// initial traversal, then runs an insertion based on them after the overloaded +// traversal. +// +// Order of mutations may be important, mutations are ordered according to the +// following rules: +// Before/After insertions are ordered as registered when reverse_order == +// false, +// +// Before/After insertions are in reverse order as registered when +// reverse_order == true, +// +// Before/After insertions are done before Expr replacements, so reference for +// insertions must be on pre-replaced Exprs +// +// To place in a scope that is empty, simply provide a nullptr reference +// Since insertions are done in order, it's possible to insert an expression in +// an empty scope, and then use that inserted scope as a reference for +// subsequent mutations. +class ExprMutator : public IrVisitor { + protected: + std::vector traverseAndInsert( + const std::vector& expr, + bool reverse_order = false); + + std::vector mutate(bool reverse_order = false); + + using IrVisitor::handle; + // Registration function which *don't* need to be called "in place" during + // visiting. + void registerInsertBefore(Expr* reference, Expr* new_expr, Scope* scope); + void registerInsertAfter(Expr* reference, Expr* new_expr, Scope* scope); + void registerReplace(Expr* reference, Expr* new_expr, Scope* scope); + + // Registration function which need to be called "in place" during visiting. + // I.E. + // if you want to insert before/after or replace an Expr, you must register + // when in handle(Expr*) of that expr. + void registerInsertBefore(Expr* reference, Expr* new_expr); + void registerInsertAfter(Expr* reference, Expr* new_expr); + void registerReplace(Expr* reference, Expr* new_expr); + + private: + enum class MutationMode { BEFORE, AFTER, REPLACE }; + + void registerMutation( + Expr* ref, + Expr* new_expr, + Scope* scope, + MutationMode mode); + + struct MutationInformation { + Expr* reference = nullptr; + Expr* new_expr = nullptr; + Scope* scope = nullptr; + MutationMode mode = MutationMode::BEFORE; + }; + + // Track insertions as they're registered + std::vector insertions_; + + // Track replacements as they're registered + std::vector replacements_; +}; + +} // namespace kir +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp deleted file mode 100644 index e00da31423c..00000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp +++ /dev/null @@ -1,451 +0,0 @@ -#include -#include - -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { -namespace kir { - -namespace { - -const char* boolLiteral(bool value) { - return value ? "true" : "false"; -} - -std::string varName(const kir::Val* val, const char* prefix) { - std::stringstream value_name; - if (val == nullptr) { - value_name << "$nullptr"; - } else if (val->name() != kInvalidStmName) { - value_name << prefix << val->name(); - } else { - value_name << "k" << prefix << val->id(); - } - return value_name.str(); -} - -} // namespace - -void IrPrinter::printNode(const kir::Node* node) { - os_ << gen(node, true); -} - -void IrPrinter::printKernel(const Kernel* kernel) { - TORCH_CHECK(kernel != nullptr); - - // kernel declaration - os_ << "\nKERNEL ("; - for (auto in : kernel->inputs()) { - os_ << gen(in); - if (in != kernel->inputs().back()) { - os_ << ", "; - } - } - os_ << ") -> ("; - for (auto out : kernel->outputs()) { - os_ << gen(out); - if (out != kernel->outputs().back()) { - os_ << ", "; - } - } - os_ << ") :\n"; - - // kernel body - startBlock(); - for (auto expr : kernel->topLevelExprs()) { - os_ << gen(expr, true); - } - endBlock(); - os_ << "END.\n\n"; -} - -std::ostream& IrPrinter::indent() { - for (const auto i : c10::irange(indent_level_)) { - (void)i; // Suppress unused variable warning - ir_str_ << kTab; - } - ir_str_ << margin_; - return ir_str_; -} - -std::string IrPrinter::gen(const kir::Node* node, bool top_level) { - if (node == nullptr) { - return "$nullptr"; - } - - // If we're generatign a top level statement we expect to start - // with an empty set of uses - TORCH_INTERNAL_ASSERT(!implicit_definition_ || uses_.empty() || !top_level); - - // Mark the node as generated - visited_.insert(node); - - // Generate the node itself - std::stringstream node_str; - std::swap(node_str, ir_str_); - node->accept(this); - std::swap(node_str, ir_str_); - - if (!implicit_definition_) { - return node_str.str(); - } - - if (top_level) { - // Implicitly mark top level nodes as used, so we - // get their definitions printed (useful for debugging) - if (auto val = dynamic_cast(node)) { - uses_.insert(val); - } - - // Make a copy of the node uses (and reset global state) - const auto node_uses = uses_; - uses_.clear(); - - std::stringstream top_level_str; - - // Hoist implicit definitions - for (auto use : node_uses) { - const auto def = use->definition(); - if (def && visited_.find(def) == visited_.end()) { - margin_ = "~ "; - top_level_str << gen(def, true); - margin_ = ""; - } - } - - top_level_str << node_str.str(); - return top_level_str.str(); - } else { - return node_str.str(); - } -} - -std::string IrPrinter::use(const kir::Val* val) { - if (val != nullptr) { - uses_.insert(val); - } - return gen(val); -} - -void IrPrinter::startBlock() { - ++indent_level_; -} - -void IrPrinter::endBlock() { - TORCH_CHECK(indent_level_ > 0); - --indent_level_; -} - -void IrPrinter::handleBlock(const kir::Scope& scope) { - // Save the uses of the parent scope - decltype(uses_) outer_uses; - std::swap(uses_, outer_uses); - - startBlock(); - for (auto expr : scope.exprs()) { - ir_str_ << gen(expr, true); - } - endBlock(); - - // Restore parent's uses - std::swap(uses_, outer_uses); -} - -void IrPrinter::visit(const kir::Bool* node) { - if (node->isConst()) { - ir_str_ << boolLiteral(*node->value()); - } else { - ir_str_ << varName(node, "b"); - } -} - -void IrPrinter::visit(const kir::Double* node) { - if (node->isConst()) { - const int digits = std::numeric_limits::max_digits10; - ir_str_ << "double(" << std::setprecision(digits) << *node->value() << ")"; - } else { - ir_str_ << varName(node, "d"); - } -} - -void IrPrinter::visit(const kir::Int* node) { - if (node->isConst()) { - ir_str_ << *node->value(); - } else { - ir_str_ << varName(node, "i"); - } -} - -void IrPrinter::visit(const kir::NamedScalar* node) { - ir_str_ << node->name(); -} - -void IrPrinter::visit(const kir::Predicate* node) { - switch (node->predicate_type()) { - case PredicateType::Inline: { - ir_str_ << "Inline"; - break; - } - case PredicateType::Manual: { - ir_str_ << node->value(); - break; - } - case PredicateType::Misaligned: { - ir_str_ << "Misaligned"; - break; - } - case PredicateType::Padding: { - ir_str_ << "Padding"; - break; - } - case PredicateType::Shift: { - ir_str_ << "Shift"; - break; - } - case PredicateType::Unswitch: { - ir_str_ << "Unswitch"; - break; - } - case PredicateType::Vectorize: { - ir_str_ << "Vectorize"; - break; - } - default: - break; - } -} - -void IrPrinter::visit(const kir::TensorIndex* node) { - ir_str_ << gen(node->view()) << "["; - for (auto index : node->indices()) { - ir_str_ << use(index); - if (index != node->indices().back()) { - ir_str_ << ", "; - } - } - ir_str_ << "]"; -} - -void IrPrinter::visit(const kir::IterDomain* node) { - ir_str_ << varName(node, "id") << "["; - if (node->isRFactorProduct()) { - ir_str_ << "rfactor."; - } - ir_str_ << node->parallelType() << "." << node->iterType() << "(" - << use(node->start()) << " .. " << use(node->extent()) << ")]"; -} - -void IrPrinter::visit(const kir::TensorDomain*) { - // TODO(kir): print Tensor shapes? - ir_str_ << "kir::TensorDomain"; -} - -void IrPrinter::visit(const kir::TensorView* node) { - // TODO(kir): print memory type too? - ir_str_ << varName(node, "T"); -} - -void IrPrinter::visit(const kir::UnaryOp* node) { - indent() << gen(node->out()) << " = "; - - auto op_type = node->operation(); - - if (auto op = inline_op_str(op_type)) { - if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { - ir_str_ << stringifyBooleanOp(op_type) << gen(node->in()); - } else { - ir_str_ << *op << gen(node->in()); - } - } else { - if (op_type == UnaryOpType::Cast) { - const auto cast_str = - cast_func_str({node->in()->dtype(), node->out()->dtype()}); - ir_str_ << cast_str.value(); - } else { - ir_str_ << op_type; - if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) { - ir_str_ << "f"; - } - } - - if (op_type == UnaryOpType::RandLike) { - ir_str_ << "(RND"; - } else { - ir_str_ << "("; - ir_str_ << use(node->in()); - } - ir_str_ << ")"; - } - - ir_str_ << "\n"; -} - -void IrPrinter::visit(const kir::BinaryOp* node) { - indent() << gen(node->out()) << " = "; - - const auto op_type = node->operation(); - const auto lhs = use(node->lhs()); - const auto rhs = use(node->rhs()); - - if (auto op = inline_op_str(op_type)) { - ir_str_ << lhs << " "; - if (alsoBooleanOperator(op_type) && - node->out()->dtype() == DataType::Bool) { - ir_str_ << stringifyBooleanOp(op_type); - } else { - ir_str_ << *op; - } - ir_str_ << " " << rhs; - } else { - ir_str_ << op_type; - if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) { - ir_str_ << "f"; - } - ir_str_ << "(" << lhs << ", " << rhs << ")"; - } - - ir_str_ << "\n"; -} - -void IrPrinter::visit(const kir::TernaryOp* node) { - indent() << gen(node->out()) << " = " << node->operation() << "(" - << use(node->in1()) << ", " << use(node->in2()) << ", " - << use(node->in3()) << ")\n"; -} - -void IrPrinter::visit(const kir::ReductionOp* node) { - indent() << gen(node->out()) << " = " - << "REDUCTION(op='" << node->operation() << "'" - << ", in=" << use(node->in()) << ", init=" << use(node->init()) - << ", pred=" << use(node->predicate()) << ")\n"; -} - -void IrPrinter::visit(const kir::WelfordOp* node) { - indent() << gen(node->outVar()) << "," << gen(node->outAvg()) << "," - << gen(node->outN()) << " = " - << "Welford( inAvg=" << use(node->inAvg()); - if (!node->inN()->isOneInt()) { - indent() << " inVar=" << use(node->inVar()); - } - indent() << " inN=" << use(node->inN()); - if (!node->initN()->isZeroInt()) { - indent() << ", initVar=" << use(node->initVar()) - << " initAvg=" << use(node->initAvg()) - << " initN=" << use(node->initN()); - } - indent() << ", pred=" << use(node->predicate()) << ")\n"; -} - -void IrPrinter::visit(const kir::GridReduction* node) { - const auto* reduction_op = node->reduction_op(); - indent() << gen(reduction_op->out()) << " = " - << "GRID_REDUCTION(op='" << reduction_op->operation() << "'" - << ", in=" << use(reduction_op->in()) - << ", init=" << use(reduction_op->init()) - << ", pred=" << use(reduction_op->predicate()) << ")\n"; - indent() << kTab << kTab - << ".reduction_buffer=" << use(node->reduction_buffer()->buffer()) - << "\n"; - indent() << kTab << kTab - << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n"; - indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n"; -} - -void IrPrinter::visit(const kir::GridWelford* node) { - const auto* welford_op = node->welford_op(); - indent() << gen(welford_op->outVar()) << "," << gen(welford_op->outAvg()) - << "," << gen(welford_op->outN()) << " = " - << "GRID_WELFORD(" - << "inAvg=" << use(welford_op->inAvg()); - if (!welford_op->inN()->isOneInt()) { - indent() << ", inVar=" << use(welford_op->inVar()); - } - indent() << ", inN=" << use(welford_op->inN()); - if (!welford_op->initN()->isZeroInt()) { - indent() << ", initVar=" << use(welford_op->initVar()) - << " initAvg=" << use(welford_op->initAvg()) - << " initN=" << use(welford_op->initN()); - } - indent() << ", pred=" << use(welford_op->predicate()) << ")\n"; - indent() << kTab << kTab - << ".var_buffer=" << use(node->var_buffer()->buffer()) - << ".avg_buffer=" << use(node->avg_buffer()->buffer()) - << ".n_buffer=" << use(node->N_buffer()->buffer()) << "\n"; - indent() << kTab << kTab - << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n"; - indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n"; -} - -void IrPrinter::visit(const kir::BroadcastOp* node) { - indent() << gen(node->out()) << " = BROADCAST(" << use(node->in()) << ")\n"; -} - -void IrPrinter::visit(const kir::ForLoop* node) { - indent() << "FOR " << gen(node->index()) << " in " << gen(node->iter_domain()) - << ":\n"; - handleBlock(node->body()); -} - -void IrPrinter::visit(const kir::IfThenElse* node) { - indent() << "IF " << use(node->predicate()) << ":\n"; - handleBlock(node->thenBody()); - if (node->hasElse()) { - indent() << "ELSE:\n"; - handleBlock(node->elseBody()); - } -} - -void IrPrinter::visit(const kir::Allocate* node) { - indent() << gen(node->buffer()) << " = ALLOCATE(" - << "mem_type=" << node->memoryType() << ", " - << "size=" << use(node->size()) << ", " - << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n"; - if (node->alias() != nullptr) { - indent() << kTab << kTab << ".alias=" << gen(node->alias()->buffer()) - << "\n"; - } -} - -void IrPrinter::visit(const kir::Sync* node) { - indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync()) - << ")\n"; -} - -void IrPrinter::visit(const kir::InitMagicZero* node) { - indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n"; -} - -void IrPrinter::visit(const kir::UpdateMagicZero* node) { - indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; -} - -std::string toString(const kir::Node* stmt, bool implicit_definitions) { - std::stringstream ss; - IrPrinter ir_printer(ss, implicit_definitions); - ir_printer.printNode(stmt); - return ss.str(); -} - -std::string toString( - const std::vector& exprs, - bool implicit_definitions) { - std::stringstream ss; - IrPrinter ir_printer(ss, implicit_definitions); - for (auto expr : exprs) { - ir_printer.printNode(expr); - } - return ss.str(); -} - -} // namespace kir -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h b/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h deleted file mode 100644 index 115901a031a..00000000000 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_printer.h +++ /dev/null @@ -1,129 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include -#include -#include -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { -namespace kir { - -//! Define pretty printing functions for Kernel IR nodes -//! -//! This class is intended for debug printing, so it attempts -//! to handle invalid IR states as much as possible. -//! -//! implicit_definition_ = true will recurisvely print the definition of all -//! inputs to an expression if they haven't been printed. -class TORCH_CUDA_CU_API IrPrinter : private kir::IrVisitor { - static constexpr char const* kTab = " "; - - public: - //! Constructs a new IrPrinter which outputs to the specified stream - explicit IrPrinter(std::ostream& os, bool implicit_definition = true) - : os_(os), implicit_definition_(implicit_definition) {} - - //! Print a single Kernel IR node - void printNode(const kir::Node* node); - - //! Print a complete Kernel definition - void printKernel(const Kernel* kernel); - - private: - // Generates a string representation of an IR node - // - // If `top_level` is true, all the value uses are tracked and - // their definitions are implicitly printed before the node itself - // - std::string gen(const kir::Node* node, bool top_level = false); - - // Generate a string representation of an used value - // (this helps automatically tracking the value uses) - std::string use(const kir::Val* val); - - std::ostream& indent(); - - void startBlock(); - void endBlock(); - void handleBlock(const kir::Scope& scope); - - void visit(const kir::Bool*) final; - void visit(const kir::Double*) final; - void visit(const kir::Int*) final; - void visit(const kir::NamedScalar*) final; - void visit(const kir::Predicate*) final; - - void visit(const kir::TensorIndex*) final; - void visit(const kir::IterDomain*) final; - void visit(const kir::TensorDomain*) final; - void visit(const kir::TensorView*) final; - - void visit(const kir::UnaryOp*) final; - void visit(const kir::BinaryOp*) final; - void visit(const kir::TernaryOp*) final; - void visit(const kir::ReductionOp*) final; - void visit(const kir::WelfordOp*) final; - void visit(const kir::BroadcastOp*) final; - - void visit(const kir::GridReduction*) final; - void visit(const kir::GridWelford*) final; - void visit(const kir::ForLoop*) final; - void visit(const kir::IfThenElse*) final; - void visit(const kir::Allocate*) final; - void visit(const kir::Sync*) final; - void visit(const kir::InitMagicZero*) final; - void visit(const kir::UpdateMagicZero*) final; - - private: - std::ostream& os_; - - // Current indentation level - int indent_level_ = 0; - - // Internal IR generation stream - std::stringstream ir_str_; - - // Tracks the set of nodes which have been printed - std::unordered_set visited_; - - // Optional left margin printed after the indentation - const char* margin_ = ""; - - // The set of values used by the current top-level IR node - std::unordered_set uses_; - - // If the definition of all inputs to an expression haven't been printed - // already implicit_definition_ = true will print them before printing the - // requested node. - bool implicit_definition_ = true; -}; - -//! Returns the string representation of a Kernel IR node. If the definition of -//! all inputs to an expression haven't been printed already -//! implicit_definition_ = true will print them before printing the requested -//! node. -TORCH_CUDA_CU_API std::string toString( - const kir::Node* stmt, - bool implicit_definitions = true); - -//! Returns the string representation of a vector of kir::Expr, convenient -//! debugm echanism during lowering. If the definition of all inputs to an -//! expression haven't been printed already implicit_definition_ = true will -//! print them before printing the requested node. -TORCH_CUDA_CU_API std::string toString( - const std::vector& exprs, - bool implicit_definitions = true); - -} // namespace kir -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 036eee58206..21eb6e02fb8 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -6,18 +6,19 @@ #include #include #include -#include #include #include +#include #include +#include #include #include #include #include #include #include +#include #include -#include #include #include #include @@ -33,152 +34,15 @@ namespace jit { namespace fuser { namespace cuda { -// TODO(kir): revisit this thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT namespace { -// Going to generate a map of tensor view root domain extents to reduce the -// number used during lowering. For example if we have: -// -// T2[i0, i1] = T1[i0, i1] + T2[i2, i3] -// -// We know it would be safe to use: -// -// T2[i0, i1] = T1[i0, i1] + T2[i0, i1] -// -// And that way we don't generate T2.size[0] and T2.size[1], instead we will -// reuse T1.size[0] and T1.size[1] -// This is important when doing CSE as T2 and T1 would otherwise look like -// they're using different values, even though we know they're the same -// -// There's some duplicate logic here that's in computeAt map, but it's not so -// concice there to pull out. May want to consider making this mapping its own -// class especially as it may be useful during scheduling. -std::unordered_map getSimplificationMap(Fusion* fusion) { - std::list> disjoint_root_sets; - std::unordered_map*> - id_to_disjoint_root_set; - - auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set]( - IterDomain* id0, IterDomain* id1) { - if (id0->isBroadcast() || id1->isBroadcast()) { - return; - } - - auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0); - auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1); - bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end(); - bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end(); - - if (set_0_found && set_1_found) { - if (disjoint_set_0_it->second == disjoint_set_1_it->second) { - return; - } - // merge second disjoint set into first - auto* set_0 = disjoint_set_0_it->second; - auto* set_1 = disjoint_set_1_it->second; - for (auto id : *set_1) { - set_0->emplace(id); - id_to_disjoint_root_set[id] = set_0; - } - // remove second set from disjoint_root_sets - disjoint_root_sets.erase(std::find( - disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1)); - } else if (set_0_found || set_1_found) { - auto existing_set = - set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second; - auto to_add_id = set_0_found ? id1 : id0; - existing_set->emplace(to_add_id); - id_to_disjoint_root_set[to_add_id] = existing_set; - // add entry into existing set - } else { - // create new set entry - disjoint_root_sets.emplace_back(std::unordered_set()); - auto* new_set = &disjoint_root_sets.back(); - new_set->emplace(id0); - new_set->emplace(id1); - id_to_disjoint_root_set[id0] = new_set; - id_to_disjoint_root_set[id1] = new_set; - } - }; - - auto fusion_vals = fusion->usedMathVals(); - for (auto producer_tv : ir_utils::filterByType(fusion_vals)) { - auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv); - for (auto consumer_tv : consumer_tvs) { - auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); - auto c2p_root_map = pairwise_map.mapConsumerToProducer( - consumer_tv->domain(), producer_tv->domain()); - for (auto entry : c2p_root_map) { - auto c_id = entry.first; - auto p_id = entry.second; - map_root_ids(p_id, c_id); - } - } - } - - // Map each set to an input ID (if it exists) that has the smallest ->name() - // entry value - std::unordered_map*, IterDomain*> - set_to_input_id; - - // Loop over the root domains, of the inputs to the fusion. Pick an input ID - // to use as the representative ID of the collected sets. Only consider inputs - // as those are the ones that map to values like "T0.size[1]". They are he - // ID's that propagated their extents into the problem. We could also check - // the outputs as we do have C++ examples of using output dimensions for the - // problem size instead of inputs. However, we don't do anything where we can - // translate to those kinds of kernels integrated into PyTorch. - for (auto input_tv : ir_utils::filterByType(fusion->inputs())) { - for (auto id : - TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) { - auto id_set_it = id_to_disjoint_root_set.find(id); - if (id_set_it == id_to_disjoint_root_set.end()) { - continue; - } - auto* id_set = id_set_it->second; - if (set_to_input_id.find(id_set) == set_to_input_id.end()) { - set_to_input_id[id_set] = id; - } else { - auto input_id_of_set = set_to_input_id.at(id_set); - // Swap id's if new name is less than previously set - bool swap_ids = id->name() < input_id_of_set->name(); - // If new id is a const scalar but previously was'nt use the const - // scalar - swap_ids = swap_ids || - (id->extent()->isConstScalar() && - !input_id_of_set->extent()->isConstScalar()); - // If previous scalar was const and new isn't, don't swap - swap_ids = swap_ids && - !(input_id_of_set->extent()->isConstScalar() && - !id->extent()->isConstScalar()); - - if (swap_ids) { - set_to_input_id[id_set] = id; - } - } - } - } - - // Finally make map from ID extents to the representitive ID extent. - std::unordered_map extent_to_min_input_id_extent; - for (auto entry : set_to_input_id) { - auto* set = entry.first; - auto input_id = entry.second; - for (auto id : *set) { - extent_to_min_input_id_extent[id->extent()] = input_id->extent(); - } - } - return extent_to_min_input_id_extent; -} - -class KIRCleaner : public kir::MutableIrVisitor { +class KIRCleaner : public OptOutDispatch { public: //! Remove nop IR nodes - static std::vector cleanUp( - const std::vector& loop_nests) { + static std::vector cleanUp(const std::vector& loop_nests) { KIRCleaner cleaner; - std::vector out_loop_nests; + std::vector out_loop_nests; for (auto loop_nest : loop_nests) { cleaner.handle(loop_nest); // No need to keep the loop nest if it's determined to be nop @@ -190,16 +54,17 @@ class KIRCleaner : public kir::MutableIrVisitor { } private: - void handle(kir::Expr* expr) { + using OptOutDispatch::handle; + void handle(Expr* expr) final { if (expr->isA() || expr->isA()) { - expr->accept(this); + OptOutDispatch::handle(expr); } else { // Any non-scoping expr is not considered nop is_nop_ = false; } } - void visit(kir::ForLoop* fl) final { + void handle(kir::ForLoop* fl) final { auto exprs = fl->body().exprs(); fl->body().clear(); for (auto expr : exprs) { @@ -213,7 +78,7 @@ class KIRCleaner : public kir::MutableIrVisitor { is_nop_ = fl->body().empty(); } - void visit(kir::IfThenElse* ite) final { + void handle(kir::IfThenElse* ite) final { const auto conditional = ite->predicate()->value(); // Visit the then block @@ -248,9 +113,8 @@ class KIRCleaner : public kir::MutableIrVisitor { // conditional and move the exprs in the else block to the then // block. if (then_nop && !else_nop) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - kir::Bool* pred = ite->predicate()->value(); - kir::Bool* not_pred = ir_builder.notExpr(pred)->as(); + Bool* pred = ite->predicate()->value(); + Bool* not_pred = SimplifyingIrBuilder::notExpr(pred)->as(); ite->predicate()->setValue(not_pred); for (auto expr : ite->elseBody().exprs()) { ite->thenBody().push_back(expr); @@ -269,84 +133,6 @@ class KIRCleaner : public kir::MutableIrVisitor { } // namespace -void GpuLower::replaceSymbolicSizes() { - FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes"); - - kir::IrBuilder ir_builder(kernel()); - - // Grab inputs and outputs - std::vector inputs_and_outputs; - for (auto val : fusion_->inputs()) { - if (ir_utils::isTV(val)) { - inputs_and_outputs.push_back(val->as()); - } - } - // Symbolic size is necessary for outputs if there are no inputs. - // Otherwise infer output sizes from the inputs via expression evaluation. - if (fusion_->inputs().empty()) { - for (auto val : fusion_->outputs()) { - if (ir_utils::isTV(val)) { - inputs_and_outputs.push_back(val->as()); - } - } - } - - // Generate map for all tensorview root domain values to map them to symbolic - // values. i.e. T0->getRootDomain()[0] would map to a named scalar - // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir. - for (TensorView* tv : inputs_and_outputs) { - // Replace the domain with one based on Ti.size[j] - const std::vector& root_td = tv->getRootDomain(); - - size_t dim = 0; - for (auto id : root_td) { - const Val* orig_size = id->extent(); - - // Output sizes could have reduction axes, which isn't what gets output. - // NOLINTNEXTLINE(bugprone-branch-clone) - if (id->isReduction() || - (id->getIterType() == IterType::BroadcastWithoutStride)) { - continue; - } else if ( - id->isRFactorProduct() || - // NOLINTNEXTLINE(bugprone-branch-clone) - (id->getIterType() == IterType::BroadcastWithStride) || - orig_size->isConstScalar()) { - dim++; - continue; - } - - // TODO(kir): consider a different implementation which doesn't - // hijack the kir_val_map_ - // Currently turn off this part for inputs of segmented fusion, - // since FusionKernelRuntime will provide these as integer inputs - if (kir_val_map_.find(orig_size) == kir_val_map_.end() && - !orig_size->isFusionInput() && !orig_size->isConstScalar()) { - std::stringstream ss; - ss << "T" << tv->name() << ".size[" << dim++ << "]"; - kir_val_map_[orig_size] = ir_builder.create( - ss.str(), orig_size->getDataType().value()); - } else { - dim++; - } - } - } - - // Use a minimal number of sizes from provided tensors. - auto extent_simplification_map = getSimplificationMap(fusion_); - for (auto extent_entry : extent_simplification_map) { - auto orig_extent = extent_entry.first; - auto simplified_extent = extent_entry.second; - if (kir_val_map_.count(orig_extent)) { - if (kir_val_map_.count(simplified_extent)) { - kir_val_map_[orig_extent] = kir_val_map_[simplified_extent]; - } else { - kir_val_map_[orig_extent] = lowerValue(simplified_extent); - } - } - } -} - void GpuLower::collectPaddedParallelDims() { ExpressionEvaluator ee(fusion_); bool can_be_single_warp = true; @@ -398,14 +184,12 @@ void GpuLower::collectPaddedParallelDims() { } } -void GpuLower::lower() { +void GpuLower::lower(Fusion* fusion) { FUSER_PERF_SCOPE("GpuLower::lower"); - - TORCH_INTERNAL_ASSERT(fusion_ != nullptr); + TORCH_INTERNAL_ASSERT(fusion != nullptr); TORCH_INTERNAL_ASSERT( active_gpu_lower == nullptr, "Nested lowering passes are not supported"); - // TODO(kir): revisit this struct LowerGuard { LowerGuard(GpuLower* gpu_lower) { active_gpu_lower = gpu_lower; @@ -414,17 +198,21 @@ void GpuLower::lower() { active_gpu_lower = nullptr; } } lower_guard(this); + // Copy fusion into a new kernel for processing + kernel_ = std::make_unique(fusion); + // Alias the fusion kernel caries around as a view of itself. + fusion_ = kernel_.get(); FusionGuard fg(fusion_); - - // Start with a fresh kernel - kernel_ = std::make_unique(); - // prepare for lowering validateIr(fusion_); - replaceSymbolicSizes(); + collectPaddedParallelDims(); - trivial_reduction_info_.build(fusion_, this); + + replaceSymbolicSizes(fusion_); + + trivial_reduction_info_.build(fusion_); + trivialReductionReplacement(fusion_, trivialReductionInfo()); // In the future we may directly use this map, but for now it will propagate // and validate (to some extent) the parallelization strategy. @@ -447,9 +235,12 @@ void GpuLower::lower() { parallelDimensionMap().build(fusion_); if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) { - std::cout << parallelDimensionMap().toString(); + std::cout << "Parallel dimension map:" << std::endl; + std::cout << parallel_dimension_map_.toString() << std::endl; } + concretized_broadcast_domains_.build(fusion_); + // Compute thread predicates. Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); @@ -469,61 +260,67 @@ void GpuLower::lower() { nonDivisibleSplitInfo().build(fusion_); - // Set the kernel inputs & outputs - for (auto input : fusion_->inputs()) { - kernel_->addInput(GpuLower::lowerValue(input)); - } - - for (auto output : fusion_->outputs()) { - kernel_->addOutput(GpuLower::lowerValue(output)); - } + doubleBufferInfo().build(fusion_); // Run our passes keeping the lowered expressions and forwarding // them // Reorder expressions for loop-nest generation respecting computeAt // relationships - auto sorted_exprs = reorderExprsForComputeAt(); + const auto exprs_sorted = reorderExprsForComputeAt(); // Generate loop-nests and place each expression at its // corresponding loop - const auto lowered_exprs = LoopNestGenerator::loweredExprs(sorted_exprs); + const auto exprs_lowered = LoopNestGenerator::loweredExprs(exprs_sorted); + + // Replace trivial reductions, Transpose, Shift, Gather, and View ops with + // unary ops since they're not separately processed in lowering. + const auto exprs_unary_replaced = unarySetOpInserter(exprs_lowered); // Insert allocations - const auto alloced_exprs = insertAllocations(lowered_exprs); + const auto exprs_alloced = insertAllocations(exprs_unary_replaced); // Insert read after write smem syncs - const auto raw_sync_exprs = insertRawThreadSynchronization(alloced_exprs); + const auto exprs_raw_sync = insertRawThreadSynchronization(exprs_alloced); // Reuse memory locations - const auto reuse_mem_exprs = reuseMemoryAllocations(raw_sync_exprs); + const auto exprs_reuse_mem = reuseMemoryAllocations(exprs_raw_sync); - // Inserts predicates after this, need to be careful in later passes when - // inserting in loop nest structure as insertions could be on if then else - // instead of directly on a for loop - const auto unrolled_loops = UnrollPass::runPass(fusion_, reuse_mem_exprs); + // Insert SyncThreads at end of for-loop to avoid WAR race condition + const auto exprs_war_sync = insertWarThreadSynchronization(exprs_reuse_mem); - const auto unrolled_mv_loops = - processMisalignedVectorization(fusion_, unrolled_loops); + const auto exprs_double_buffered = DoubleBufferPass::run(exprs_war_sync); - // Insert SyncThreads at end of for-loop to avoid WAR race condition - const auto war_sync_exprs = insertWarThreadSynchronization(unrolled_mv_loops); + // This pass inserts predicates as well as branches in the code. Up until now + // the code is explicitly single shot for loop based. Need to be careful in + // later passes when doing any kind of insertions in loop nest structure as + // insertions could be on if then or else instead of directly on a for loop. + const auto exprs_unrolled_loops = + UnrollPass::runPass(fusion_, exprs_double_buffered); - const auto indexed_loops = IndexLowering::getIndexedExprs(war_sync_exprs); + const auto exprs_unrolled_mv_loops = + processMisalignedVectorization(exprs_unrolled_loops); - const auto exprs_with_fused_broadcast = fuseWarpReduce(indexed_loops); + const auto exprs_indexed_loops = + IndexLowering::getIndexedExprs(exprs_unrolled_mv_loops); - const auto conditional_loops = - generateConditionalFromPredicate(fusion_, exprs_with_fused_broadcast); + // TODO: It seems this type of optimization would be far easier to implement + // on fusion ir than kernel ir. We should likely refactor this to at least run + // before allocation insertion. + const auto exprs_with_fused_broadcast = fuseWarpReduce(exprs_indexed_loops); + + const auto exprs_conditional_loops = + generateConditionalFromPredicate(exprs_with_fused_broadcast); // Insert fake zero updates to make sure nvrtc doesn't blow out register use // on index and predicate reuse - const auto register_adjusted = insertMagicZero(conditional_loops); + const auto exprs_register_adjusted = insertMagicZero(exprs_conditional_loops); - const auto cleaned_up_loops = KIRCleaner::cleanUp(register_adjusted); + const auto exprs_cleaned_up_loops = + KIRCleaner::cleanUp(exprs_register_adjusted); // We now have the lowered expressions, finalize the kernel IR - kernel_->finalize(cleaned_up_loops); + kernel_->finalize(exprs_cleaned_up_loops); } kir::Kernel* GpuLower::kernel() const { @@ -531,213 +328,9 @@ kir::Kernel* GpuLower::kernel() const { return kernel_.get(); } -// Maps Fusion IR nodes to the Kernel IR counterparts -class GpuLower::KernelIrMapper : private OptInConstDispatch { - public: - explicit KernelIrMapper(GpuLower* gpu_lower) - : gpu_lower_(gpu_lower), ir_builder_(gpu_lower->kernel()) {} - - kir::Val* lowerValue(const Val* value) { - const auto it = gpu_lower_->kir_val_map_.find(value); - if (it != gpu_lower_->kir_val_map_.end()) { - return it->second; - } else { - handle(value); - const auto kir_value = gpu_lower_->kir_val_map_[value]; - TORCH_CHECK(kir_value != nullptr); - - // Lower the value definition, if any - if (value->isScalar()) { - if (auto def = value->definition()) { - const auto kir_def = lowerExpr(def); - TORCH_INTERNAL_ASSERT(kir_value->definition() == kir_def); - } - } - - return kir_value; - } - } - - kir::Expr* lowerExpr(const Expr* expr) { - const auto it = gpu_lower_->kir_expr_map_.find(expr); - if (it != gpu_lower_->kir_expr_map_.end()) { - return it->second; - } else { - handle(expr); - const auto lowered_node = gpu_lower_->kir_expr_map_[expr]; - TORCH_CHECK(lowered_node != nullptr); - return lowered_node; - } - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - } - - private: - void handle(const Statement* node) final { - OptInConstDispatch::handle(node); - } - - void handle(const Val* node) final { - OptInConstDispatch::handle(node); - } - - void handle(const Expr* node) final { - OptInConstDispatch::handle(node); - } - - void handle(const TensorDomain* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const IterDomain* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const TensorView* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const Bool* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const Double* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const Int* node) final { - const auto lowered_node = ir_builder_.create(node); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const NamedScalar* node) final { - const auto lowered_node = ir_builder_.create( - node->name(), node->getDataType().value()); - TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second); - } - - void handle(const UnaryOp* node) final { - const auto lowered_node = ir_builder_.create( - node->getUnaryOpType(), - lowerValue(node->out()), - lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const BinaryOp* node) final { - const auto lowered_node = ir_builder_.create( - node->getBinaryOpType(), - lowerValue(node->out()), - lowerValue(node->lhs()), - lowerValue(node->rhs())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const TernaryOp* node) final { - const auto lowered_node = ir_builder_.create( - node->getTernaryOpType(), - lowerValue(node->out()), - lowerValue(node->in1()), - lowerValue(node->in2()), - lowerValue(node->in3())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const ReductionOp* node) final { - auto out_tv = node->out()->as(); - // If trivial reduction operation lower to set operation. - if (std::all_of( - out_tv->domain()->domain().begin(), - out_tv->domain()->domain().end(), - [&](IterDomain* id) { - // If id is a reduction axis, is it a trivial reduction? - if (id->isReduction()) { - return gpu_lower_->trivialReductionInfo().isDerived(id); - } else { - return true; - } - })) { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK( - gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - return; - } - - const auto lowered_node = ir_builder_.create( - node->getReductionOpType(), - lowerValue(node->init()), - lowerValue(node->out()), - lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const WelfordOp* node) final { - auto lowerOptional = [&](Val* v) { return v ? lowerValue(v) : nullptr; }; - const auto lowered_node = ir_builder_.create( - lowerValue(node->outVar()), - lowerValue(node->outAvg()), - lowerValue(node->outN()), - lowerValue(node->initVar()), - lowerValue(node->initAvg()), - lowerValue(node->initN()), - lowerOptional(node->inVar()), - lowerValue(node->inAvg()), - lowerValue(node->inN())); - - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const BroadcastOp* node) final { - const auto lowered_node = ir_builder_.create( - lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const TransposeOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const ShiftOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const GatherOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - void handle(const ViewOp* node) final { - const auto lowered_node = ir_builder_.create( - UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in())); - TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second); - } - - private: - GpuLower* gpu_lower_ = nullptr; - kir::IrBuilder ir_builder_; -}; - -kir::Val* GpuLower::lowerValue(const Val* val) { - KernelIrMapper kir_mapper(this); - return kir_mapper.lowerValue(val); -} - -kir::Expr* GpuLower::lowerExpr(const Expr* expr) { - KernelIrMapper kir_mapper(this); - return kir_mapper.lowerExpr(expr); -} - GpuLower* GpuLower::current() { + TORCH_INTERNAL_ASSERT( + active_gpu_lower != nullptr, "No active GpuLower available"); return active_gpu_lower; } diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index b807bb4d480..b97c6ac1837 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -1,14 +1,17 @@ #pragma once -#include +#include #include #include #include #include #include +#include #include #include +#include +#include #include #include #include @@ -29,29 +32,27 @@ namespace cuda { // container for this information that we can reuse. Would be nice to generate // such a structure and propagate it through lowering. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class TORCH_CUDA_CU_API GpuLower { +class TORCH_CUDA_CU_API GpuLower : public NonCopyable { class KernelIrMapper; public: - GpuLower() = default; + GpuLower() = delete; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit GpuLower(Fusion* fusion) : fusion_(fusion) { - lower(); + explicit GpuLower(Fusion* fusion) { + lower(fusion); } kir::Kernel* kernel() const; - //! Converts a Fusion IR value into the Kernel IR equivalent - kir::Val* lowerValue(const Val* val); - - //! Converts a Fusion IR expression into the Kernel IR equivalent - kir::Expr* lowerExpr(const Expr* expr); - //! Returns the currently active lowering object //! (or nullptr if no lowering is in progress) static GpuLower* current(); + ConcretizedBroadcastDomains& concretizedBroadcastDomains() { + return concretized_broadcast_domains_; + } + const ThreadPredicateMap& threadPredMap() const { return thread_pred_map_; } @@ -68,7 +69,7 @@ class TORCH_CUDA_CU_API GpuLower { return ca_parallel_map_; } - const auto& trivialReductionInfo() const { + const TrivialReductionInfo& trivialReductionInfo() const { return trivial_reduction_info_; } @@ -120,16 +121,12 @@ class TORCH_CUDA_CU_API GpuLower { return non_divisible_split_info_; } - private: - void lower(); + DoubleBufferInfo& doubleBufferInfo() { + return double_buffer_info_; + } - // TensorViews are all based on symbolic sizes. When we first initialize them - // we don't know if they're inputs or outputs which would mean that they have - // runtime shapes. Intermediate tensors (those not going to global memory) do - // not have this information. Since we need to have the correct information in - // the kernel being fetched for shapes, we want to replace input and output - // tensors to reference the runtime structure containing sizes. - void replaceSymbolicSizes(); + private: + void lower(Fusion* fusion); // Goes through the parallelized iterdomains of the used TVs and find // the parallel dimensions that need to be padded to a multiples of @@ -140,11 +137,8 @@ class TORCH_CUDA_CU_API GpuLower { // Lowered Kernel IR std::unique_ptr kernel_; - // Fusion IR node to Kernel IR node mapping - std::unordered_map kir_val_map_; - std::unordered_map kir_expr_map_; - // Some stateful information during lowering + ConcretizedBroadcastDomains concretized_broadcast_domains_; ThreadPredicateMap thread_pred_map_; PredicateElimination pred_elimination_; ComputeAtMap ca_loop_map_; @@ -157,6 +151,7 @@ class TORCH_CUDA_CU_API GpuLower { ParallelDimensionMap parallel_dimension_map_; PartialSplitMap partial_split_map_; NonDivisibleSplitInfo non_divisible_split_info_; + DoubleBufferInfo double_buffer_info_; Fusion* fusion_ = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index 80e2e58c9cf..17a2db069d8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -1,10 +1,10 @@ #include #include +#include #include #include #include -#include #include #include @@ -22,40 +22,42 @@ namespace { //! Get string representation of Allocate size for symbolic comparison //! //! TODO: Some expr simplifications could also be helpful -class SymbolicSizePrinter : private kir::IrVisitor { +class SymbolicSizePrinter : private OptOutConstDispatch { public: static std::string printSize(const kir::Allocate* allocate) { SymbolicSizePrinter printer; - allocate->size()->accept(&printer); + printer.handle(allocate->size()); return printer.os_.str(); } private: - void visit(const kir::Int* node) final { + using OptOutConstDispatch::handle; + + void handle(const Int* node) final { if (auto def = node->definition()) { - def->accept(this); + OptOutConstDispatch::handle(def); } else if (node->isConst()) { os_ << *node->value(); } else { - os_ << "ki" << node->id(); + os_ << "ki" << node->name(); } } - void visit(const kir::NamedScalar* named_scalar) final { + void handle(const NamedScalar* named_scalar) final { os_ << "@" << named_scalar->name(); } - void visit(const kir::UnaryOp* unary_op) final { - os_ << unary_op->operation() << "("; - unary_op->in()->accept(this); + void handle(const UnaryOp* unary_op) final { + os_ << unary_op->getUnaryOpType() << "("; + OptOutConstDispatch::handle(unary_op); os_ << ")"; } - void visit(const kir::BinaryOp* binary_op) final { - os_ << binary_op->operation() << "("; - binary_op->lhs()->accept(this); + void handle(const BinaryOp* binary_op) final { + os_ << binary_op->getBinaryOpType() << "("; + OptOutConstDispatch::handle(binary_op->lhs()); os_ << ","; - binary_op->rhs()->accept(this); + OptOutConstDispatch::handle(binary_op->rhs()); os_ << ")"; } @@ -74,11 +76,11 @@ class BufferReuseDebugPrinter { DebugLineType line_type = DebugLineType::EXPR; }; - using DebugEntry = std::pair; + using DebugEntry = std::pair; using DebugEntryPtr = std::unique_ptr; public: - BufferReuseDebugPrinter() : ir_printer_(os_, false){}; + BufferReuseDebugPrinter() : ir_printer_(os_){}; std::string dumpDebugInfo() { os_.clear(); @@ -105,7 +107,7 @@ class BufferReuseDebugPrinter { private: friend class BufferUseDefInfo; - void pushBack(int lineno, kir::Expr* expr) { + void pushBack(int lineno, Expr* expr) { makeExprEntry(lineno, expr); } @@ -117,7 +119,7 @@ class BufferReuseDebugPrinter { makeScopeEntry(DebugLineType::END_BLOCK); } - void makeExprEntry(int lineno, kir::Expr* expr) { + void makeExprEntry(int lineno, Expr* expr) { auto debug_entry_ptr = std::make_unique(); debug_entry_ptr->first.lineno = lineno; debug_entry_ptr->second = expr; @@ -134,14 +136,14 @@ class BufferReuseDebugPrinter { debug_info_.emplace_back(std::move(debug_entry_ptr)); } - void handle(const kir::Expr* node) { + void handle(const Expr* node) { if (auto for_loop = dynamic_cast(node)) { handle(for_loop); } else if (auto ite = dynamic_cast(node)) { handle(ite); } else { indent(); - ir_printer_.printNode(node); + ir_printer_.handle(node); } if (auto alloc = dynamic_cast(node)) { printAllocInfo(alloc); @@ -151,9 +153,9 @@ class BufferReuseDebugPrinter { void handle(const kir::ForLoop* node) { indent(); os_ << "FOR "; - ir_printer_.printNode(node->index()); + ir_printer_.handle(node->index()); os_ << " in "; - ir_printer_.printNode(node->iter_domain()); + ir_printer_.handle(node->iter_domain()); os_ << ":\n"; } @@ -186,7 +188,7 @@ class BufferReuseDebugPrinter { private: std::stringstream os_; - kir::IrPrinter ir_printer_; + IrPrinter ir_printer_; int indent_level_ = 0; std::vector debug_info_; @@ -340,7 +342,7 @@ class BufferUseDefInfo { static constexpr long kRegisterSizeThreshold = 1; BufferUseDefInfo( - const std::vector& exprs, + const std::vector& exprs, BufferReuseDebugPrinter* debug_printer = nullptr) : debug_printer_(debug_printer) { if (debug_printer) { @@ -410,7 +412,7 @@ class BufferUseDefInfo { } private: - void handle(kir::Expr* expr) { + void handle(Expr* expr) { current_pos_++; if (debug_printer_) { debug_printer_->pushBack(current_pos_, expr); @@ -426,7 +428,7 @@ class BufferUseDefInfo { } } - void handleScope(const std::vector& exprs) { + void handleScope(const std::vector& exprs) { if (debug_printer_) { debug_printer_->pushScope(); } @@ -460,15 +462,15 @@ class BufferUseDefInfo { return; } - auto kir_tv = dynamic_cast(alloc->buffer()); - if (!kir_tv) { + auto tv = dynamic_cast(alloc->buffer()); + if (!tv) { return; } // Collect the allocate info data // Collect memory type, skip global buffers - auto mem_type = kir_tv->memoryType(); + auto mem_type = tv->getMemoryType(); if (mem_type != MemoryType::Local && mem_type != MemoryType::Shared) { return; } @@ -487,12 +489,12 @@ class BufferUseDefInfo { } } - auto data_type = kir_tv->dtype(); + auto data_type = tv->dtype(); auto size_print = SymbolicSizePrinter::printSize(alloc); // Make sure we don't have conflicting information on record TORCH_INTERNAL_ASSERT(!map_allocate_to_info_.count(alloc)); - TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(kir_tv->name())); + TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(tv->name())); // make AllocationUseDefInfo: auto alloc_info = makeUseDefInfo(); @@ -505,10 +507,10 @@ class BufferUseDefInfo { // record short cuts map_allocate_to_info_[alloc] = alloc_info; - map_tv_to_allocations_[kir_tv->name()] = alloc_info; + map_tv_to_allocations_[tv->name()] = alloc_info; } - void collectScopeUseDefInfo(const std::vector& exprs) { + void collectScopeUseDefInfo(const std::vector& exprs) { // Reset position pointer resetExprCounter(); TORCH_INTERNAL_ASSERT(global_scope_info_ != nullptr); @@ -516,14 +518,14 @@ class BufferUseDefInfo { handleScope(exprs); } - void collectScopeInfo(const std::vector& exprs) { + void collectScopeInfo(const std::vector& exprs) { // Reset position pointer resetExprCounter(); collectScopeInfoWithinLoop(exprs, nullptr); } void collectScopeInfoWithinLoop( - const std::vector& exprs, + const std::vector& exprs, kir::ForLoop* current_loop) { auto loop_info = makeScopeInfo(current_loop); for (auto expr : exprs) { @@ -584,22 +586,20 @@ class BufferUseDefInfo { // Iterate over the inputs and outputs of exprs and update // the liveness info of local buffers if applicaable. - void collectLivenessInfo(const kir::Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + void collectLivenessInfo(const Expr* expr) { + if (!ir_utils::isTvOp(expr)) { return; } - auto out_tv = expr->outputs()[0]->as(); - auto fuser_out_tv = out_tv->fuserTv(); + auto out_tv = expr->outputs()[0]->as(); // Collect all tv's that resolves broadcast in this // expr. The current analysis isn't enough to capture // their liveness range. - for (auto input_tv : - ir_utils::filterByType(expr->inputs())) { + for (auto input_tv : ir_utils::filterByType(expr->inputs())) { auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv); if (maybe_alloc_info.has_value()) { - if (isSerialBroadcastResolution(input_tv->fuserTv(), fuser_out_tv)) { + if (isSerialBroadcastResolution(input_tv, out_tv)) { maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_); } else { // Disable inner alias info for this buffer, since line number based @@ -621,8 +621,7 @@ class BufferUseDefInfo { } } } - for (auto output_tv : - ir_utils::filterByType(expr->outputs())) { + for (auto output_tv : ir_utils::filterByType(expr->outputs())) { auto maybe_alloc_info = getMaybeAllocInfoFromTV(output_tv); if (maybe_alloc_info.has_value()) { maybe_alloc_info.value()->inner_live_interval->markWrite(current_pos_); @@ -675,8 +674,7 @@ class BufferUseDefInfo { return nullptr; } - c10::optional getMaybeAllocInfoFromTV( - kir::TensorView* tv) { + c10::optional getMaybeAllocInfoFromTV(TensorView* tv) { auto alloc_it = map_tv_to_allocations_.find(tv->name()); if (alloc_it == map_tv_to_allocations_.end()) { return c10::nullopt; @@ -810,11 +808,11 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { //! Reuse Allocation nodes via pointer aliasing class AllocateReuseModifier { public: - static void modify(const std::vector& exprs) { + static void modify(const std::vector& exprs) { AllocateReuseModifier modifier(exprs); } - static void debugPrint(const std::vector& exprs) { + static void debugPrint(const std::vector& exprs) { BufferReuseDebugPrinter debug_printer; AllocateReuseModifier modifier(exprs, &debug_printer); std::cout << debug_printer.dumpDebugInfo(); @@ -822,7 +820,7 @@ class AllocateReuseModifier { private: AllocateReuseModifier( - const std::vector& exprs, + const std::vector& exprs, BufferReuseDebugPrinter* debug_printer_ = nullptr) : buffer_info_(exprs, debug_printer_) { // Perform in-place sharing first and then outer liveness @@ -941,7 +939,7 @@ class AllocateReuseModifier { return false; } - void handle(kir::Expr* expr) { + void handle(Expr* expr) { if (auto ite = dynamic_cast(expr)) { handle(ite); } else if (auto for_loop = dynamic_cast(expr)) { @@ -961,7 +959,7 @@ class AllocateReuseModifier { "lower_alias_memory: IfThenElse before unrolling is not yet supported"); } - void handleScope(const std::vector& exprs) { + void handleScope(const std::vector& exprs) { current_visible_buffer_stack_.emplace_back( std::make_unique()); for (auto expr : exprs) { @@ -990,10 +988,8 @@ class AllocateReuseModifier { } // Assume inputs are TV allocations, which should have been checked // before reaching this point. - auto this_tv = - alloc_info->alloc_expr->buffer()->as()->fuserTv(); - auto reuse_tv = - to_reuse->alloc_expr->buffer()->as()->fuserTv(); + auto this_tv = alloc_info->alloc_expr->buffer()->as(); + auto reuse_tv = to_reuse->alloc_expr->buffer()->as(); // Check the values in between the two buffers. auto vals_between_this_and_reuse = @@ -1068,8 +1064,8 @@ class AllocateReuseModifier { } bool allocationDomainsIndexMapped( - std::vector& alloc_domains, - std::vector& reuse_domains) { + std::vector& alloc_domains, + std::vector& reuse_domains) { // Require that the allocated domains are exactly mapped. if (alloc_domains.size() != reuse_domains.size()) { return false; @@ -1099,7 +1095,7 @@ class AllocateReuseModifier { // Do we have a true pointwise op? // (ie. a TV op, excluding direct assignments and reductions) bool isPointwiseTvOp(const Expr* expr) { - if (ir_utils::isTVOp(expr)) { + if (ir_utils::isTvOp(expr)) { return expr->isA() || expr->isA() || expr->isA(); } @@ -1108,7 +1104,7 @@ class AllocateReuseModifier { // Utility to capture reduction ops bool isReductionTvOp(const Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } return expr->isA() || expr->isA(); @@ -1116,7 +1112,7 @@ class AllocateReuseModifier { // Utility to capture reduction ops bool isBroadcastTvOp(const Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } return expr->isA(); @@ -1138,8 +1134,7 @@ class AllocateReuseModifier { } // namespace -std::vector reuseMemoryAllocations( - const std::vector& exprs) { +std::vector reuseMemoryAllocations(const std::vector& exprs) { FUSER_PERF_SCOPE("reuseMemoryAllocations"); bool debug_print = isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo); if (debug_print) { diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.h b/torch/csrc/jit/codegen/cuda/lower_alias_memory.h index 26b33b6d5dc..0d144b9f2f4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.h +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -28,8 +28,7 @@ namespace cuda { //! is not used after this op: //! then alias output Allocate to input Allocate. //! -std::vector reuseMemoryAllocations( - const std::vector& exprs); +std::vector reuseMemoryAllocations(const std::vector& exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index 2f70c275832..c03848ccff8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -1,10 +1,8 @@ -#include #include #include #include #include -#include -#include +#include #include #include @@ -17,8 +15,12 @@ namespace cuda { namespace { -class AllocationInserter : public kir::MutableIrVisitor { +class AllocationInserter : public kir::ExprMutator { private: + using kir::ExprMutator::handle; + + // Expanded version of BasicAllocInfo in lower_utils.h helps to track + // additional information struct AllocationInformation { // The for loop that the initialization of this allocation must be // placed in, nullptr if not within a loop @@ -26,7 +28,7 @@ class AllocationInserter : public kir::MutableIrVisitor { // The expression that the initialization of this allocation must // be placed before - kir::Expr* init_place_before = nullptr; + Expr* init_place_before = nullptr; // Keep track of the actual allocation loop. This can be different // from init_for_loop only with unswitched shared memory allocations, @@ -37,143 +39,96 @@ class AllocationInserter : public kir::MutableIrVisitor { // The expression that this allocation must be placed // before. Similar to alloc_for_loop, this is different from // init_place_before only with unswitched shared memory allocations. - kir::Expr* alloc_place_before = nullptr; + Expr* alloc_place_before = nullptr; // The allocation position relative to buffer size_t alloc_pos = 0; // The buffer this allocation is for - kir::TensorView* buffer = nullptr; - - // The allocation expression - kir::Allocate* alloc_expr = nullptr; - - // Initialization - kir::Expr* init_expr = nullptr; + TensorView* buffer = nullptr; // Info to transfer to GPU lower bool has_halo = false; // Local Iterdomains that this allocation covers - std::unique_ptr> allocation_domains; + std::unique_ptr> allocation_domains; }; // Find allocation point - void findAllocationPosition(AllocationInformation& info, kir::Expr* expr) { + // Fills info.buffer, info.alloc_pos, info.init_for_loop, + // info.init_place_before, info.alloc_for_loop, info.alloc_place_before + void fillAllocationInformation(AllocationInformation& info, Expr* expr) { size_t alloc_pos = 0; kir::ForLoop* init_for_loop = nullptr; - auto fuser_tv = info.buffer->fuserTv(); size_t fl_idx_next = 0; - - bool outer_alloc_found = false; - kir::ForLoop* alloc_for_loop = nullptr; - size_t alloc_fl_idx_next = 0; - - for (auto fl : for_loops) { - if (alloc_pos == fuser_tv->getComputeAtPosition()) { - break; - } - - if (fuser_tv->axis(alloc_pos)->isReduction()) { - const auto outputs = - FusionGuard::getCurFusion()->getTerminatingOutputs(); - TORCH_INTERNAL_ASSERT( - std::find(outputs.begin(), outputs.end(), fuser_tv) != - outputs.end(), - "Invalid computeAt of T", - fuser_tv->name(), - ". A reducation axis is detected within computeAt axes even though it is not an output tensor."); - break; - } - - auto fl_id = fl->iter_domain(); - - if (fl_id->parallelType() == ParallelType::Unroll) { - break; - } - - // Shared memory must be allocated outside of unswitched - // domains. See issue #1133. - if (fl_id->parallelType() == ParallelType::Unswitch && - fuser_tv->getMemoryType() == MemoryType::Shared) { - outer_alloc_found = true; - } - - auto local_id = gpu_lower->lowerValue(fuser_tv->axis(alloc_pos)) - ->as(); - - if (gpu_lower->caLoopMap().areMapped(local_id, fl_id)) { - alloc_pos++; - } - - init_for_loop = fl; - ++fl_idx_next; - - if (!outer_alloc_found) { - alloc_for_loop = fl; - ++alloc_fl_idx_next; + auto loop_alloc_info = + loop_utils::getAllocInformation(info.buffer, for_loops_); + + info.init_for_loop = loop_alloc_info.init_for_loop; + info.alloc_for_loop = loop_alloc_info.alloc_for_loop; + info.alloc_pos = loop_alloc_info.alloc_pos; + + auto next_fl = [](kir::ForLoop* fl, const std::vector fls) { + for (auto i : c10::irange(fls.size())) { + if (fl == fls[i]) { + if (i + 1 < fls.size()) { + return fls[i + 1]; + } + } } - } - - info.alloc_pos = alloc_pos; - info.init_for_loop = init_for_loop; + TORCH_INTERNAL_ASSERT(false, "Could not find desired loop."); + }; if (info.init_for_loop == nullptr) { - info.init_place_before = for_loops.size() > 0 ? for_loops[0] : expr; + info.init_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr; } else { - if (info.init_for_loop == for_loops.back()) { + if (info.init_for_loop == for_loops_.back()) { // Inline allocation, place before expr info.init_place_before = expr; } else { // Place allocation after the last computeAt axis // TODO: may be more efficient to place before the first non-computeAt // axis - info.init_place_before = for_loops.at(fl_idx_next); + info.init_place_before = next_fl(info.init_for_loop, for_loops_); } } // Set the allocation loop and the place_before expression in the // same way as the initialization loop and place_before expression - if (!outer_alloc_found) { + if (info.alloc_for_loop == info.init_for_loop) { info.alloc_for_loop = info.init_for_loop; info.alloc_place_before = info.init_place_before; } else { - info.alloc_for_loop = alloc_for_loop; if (info.alloc_for_loop == nullptr) { - info.alloc_place_before = for_loops.size() > 0 ? for_loops[0] : expr; + info.alloc_place_before = for_loops_.size() > 0 ? for_loops_[0] : expr; } else { // Since there must be an inner unswitched domain, // alloc_for_loop should never be the inner-most loop. - TORCH_INTERNAL_ASSERT(info.alloc_for_loop != for_loops.back()); - info.alloc_place_before = for_loops.at(alloc_fl_idx_next); + TORCH_INTERNAL_ASSERT(info.alloc_for_loop != for_loops_.back()); + info.alloc_place_before = next_fl(info.alloc_for_loop, for_loops_); } } } // Create initialization expression if init_val is non-null. - void createInitExpr(AllocationInformation& info, kir::Val* init_val) { + Expr* createInitExpr(AllocationInformation& info, Val* init_val) { if (init_val == nullptr) { - info.init_expr = nullptr; - return; + return nullptr; } - auto fuser_tv = info.buffer->fuserTv(); - - std::vector init_dims; - for (const auto axis_i : c10::irange(info.alloc_pos, fuser_tv->nDims())) { - if (info.buffer->fuserTv()->axis(axis_i)->isReduction() || - info.buffer->fuserTv()->axis(axis_i)->isBroadcast()) { + std::vector init_dims; + for (const auto axis_i : + c10::irange(info.alloc_pos, info.buffer->nDims())) { + if (info.buffer->axis(axis_i)->isReduction() || + info.buffer->axis(axis_i)->isBroadcast()) { continue; } - auto concrete_id = - gpu_lower - ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID( - fuser_tv->axis(axis_i))) - ->as(); + auto concrete_id = gpu_lower->caParallelMap().getConcreteMappedID( + info.buffer->axis(axis_i)); init_dims.push_back(concrete_id); } - kir::Expr* init_expr = ir_builder.create( - UnaryOpType::Set, info.buffer, init_val); + Expr* init_expr = + IrBuilder::create(UnaryOpType::Set, info.buffer, init_val); for (auto init_loop_it = init_dims.rbegin(); init_loop_it != init_dims.rend(); ++init_loop_it) { @@ -181,9 +136,9 @@ class AllocationInserter : public kir::MutableIrVisitor { kir::ForLoop* new_loop = nullptr; auto extent_with_halo = gpu_lower->haloInfo().getExtent(id); if (extent_with_halo) { - new_loop = ir_builder.create( + new_loop = IrBuilder::create( id, - ir_builder.create(c10::nullopt), + IrBuilder::create(c10::nullopt), nullptr, extent_with_halo, nullptr, @@ -191,31 +146,33 @@ class AllocationInserter : public kir::MutableIrVisitor { nullptr, false); } else { - new_loop = ir_builder.create(id); + new_loop = IrBuilder::create(id); } new_loop->body().push_back(init_expr); init_expr = new_loop; } - info.init_expr = init_expr; + return init_expr; } - std::vector getGlobalAllocationSizes(AllocationInformation& info) { + std::vector getGlobalAllocationSizes(AllocationInformation& info) { const auto& domain = info.buffer->domain(); - const auto& maybe_rfactor_domain = - domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain(); + const auto& maybe_rfactor_domain = domain->hasRFactor() + ? domain->getRFactorDomain() + : domain->getRootDomain(); - std::vector alloc_dims; + std::vector alloc_dims; for (const auto id : maybe_rfactor_domain) { if (id->isReduction() || id->isStride() || - id->iterType() == IterType::BroadcastWithoutStride) { + id->getIterType() == IterType::BroadcastWithoutStride) { continue; } auto extent = id->extent(); // Use halo-extended extent if found auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id); if (halo_extent.hasHalo()) { - extent = ir_builder.addExpr(extent, halo_extent.width()); + extent = IrBuilder::addExpr( + extent, IrBuilder::create(halo_extent.width())); } alloc_dims.push_back(extent); } @@ -244,7 +201,7 @@ class AllocationInserter : public kir::MutableIrVisitor { // fall back to the leaf-based allocation. // // See the FusionShiftDoubleSplit test for an example case. - std::vector getNonGlobalAllocExprWithHalo( + std::vector getNonGlobalAllocExprWithHalo( TensorView* tv, const std::vector& alloc_domains) { std::vector start_vals; @@ -255,18 +212,18 @@ class AllocationInserter : public kir::MutableIrVisitor { [](IterDomain* dom) { return dom->as(); }); // Get all exprs involved in generating the allocation IDs - auto exprs = ExprSort::getExprs(tv->fusion(), start_vals); + auto exprs = StmtSort::getExprs(tv->fusion(), start_vals); // Get the halo extent if found auto getExtent = [this](IterDomain* id) { auto extent = gpu_lower->haloInfo().getExtent(id); if (extent == nullptr) { - extent = gpu_lower->lowerValue(id->extent()); + extent = id->extent(); } return extent; }; - std::unordered_map known_extents; + std::unordered_map known_extents; // IterDomains that are allocated fully. For example, if an ID is // split and only one of them is used for allocation, that's not @@ -314,7 +271,7 @@ class AllocationInserter : public kir::MutableIrVisitor { } else { known_extents.insert( {split->in(), - ir_builder.mulExpr(outer_it->second, inner_it->second)}); + IrBuilder::mulExpr(outer_it->second, inner_it->second)}); } known_extents.erase(inner_it); known_extents.erase(outer_it); @@ -330,7 +287,7 @@ class AllocationInserter : public kir::MutableIrVisitor { } } - std::vector alloc_dims; + std::vector alloc_dims; for (auto root_axis : tv->getRootDomain()) { auto it = known_extents.find(root_axis); @@ -355,24 +312,22 @@ class AllocationInserter : public kir::MutableIrVisitor { return alloc_dims; } - std::vector getNonGlobalAllocExpr(AllocationInformation& info) { - auto fuser_tv = info.buffer->fuserTv(); - const auto memory_type = info.buffer->memoryType(); + std::vector getNonGlobalAllocExpr(AllocationInformation& info) { + const auto memory_type = info.buffer->getMemoryType(); TORCH_INTERNAL_ASSERT( memory_type != MemoryType::Global, "Invalid memory type: ", memory_type); - std::vector alloc_dims; + std::vector alloc_dims; bool has_halo = false; std::vector alloc_domains; - info.allocation_domains = std::make_unique>(); + info.allocation_domains = std::make_unique>(); - for (const auto axis_i : c10::irange(fuser_tv->nDims())) { - const auto local_id = - gpu_lower->lowerValue(fuser_tv->axis(axis_i))->as(); + for (const auto axis_i : c10::irange(info.buffer->nDims())) { + const auto local_id = info.buffer->axis(axis_i); // Don't use reduction/stride/broadcast axis in the allocation // computation @@ -381,16 +336,14 @@ class AllocationInserter : public kir::MutableIrVisitor { continue; } - auto concrete_id = - gpu_lower - ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID( - fuser_tv->axis(axis_i))) - ->as(); + auto concrete_id = gpu_lower->caParallelMap().getConcreteMappedID( + info.buffer->axis(axis_i)); const bool is_block_dim = - isParallelTypeBlockDim(concrete_id->parallelType()); + isParallelTypeBlockDim(concrete_id->getParallelType()); const bool is_thread_dim = - isParallelTypeThreadDim(concrete_id->parallelType()); - const bool is_thread = isParallelTypeThread(concrete_id->parallelType()); + isParallelTypeThreadDim(concrete_id->getParallelType()); + const bool is_thread = + isParallelTypeThread(concrete_id->getParallelType()); if (axis_i < info.alloc_pos) { // Even when the axis is outside the allocation position, if the @@ -403,7 +356,7 @@ class AllocationInserter : public kir::MutableIrVisitor { (memory_type == MemoryType::Global && is_thread))) { continue; } - alloc_domains.push_back(fuser_tv->axis(axis_i)); + alloc_domains.push_back(info.buffer->axis(axis_i)); } else { if ( // If shared memory, don't use any IDs bound to a grid dimension @@ -413,12 +366,13 @@ class AllocationInserter : public kir::MutableIrVisitor { (memory_type == MemoryType::Local && is_thread)) { continue; } - alloc_domains.push_back(fuser_tv->axis(axis_i)); + alloc_domains.push_back(info.buffer->axis(axis_i)); } auto extent = concrete_id->extent(); - if (gpu_lower->haloInfo().getExtent(fuser_tv->axis(axis_i)) != nullptr) { + if (gpu_lower->haloInfo().getExtent(info.buffer->axis(axis_i)) != + nullptr) { has_halo = true; } @@ -430,20 +384,19 @@ class AllocationInserter : public kir::MutableIrVisitor { // the halo extents from leaf IDs to root IDs if (has_halo) { info.has_halo = true; - return getNonGlobalAllocExprWithHalo(fuser_tv, alloc_domains); + return getNonGlobalAllocExprWithHalo(info.buffer, alloc_domains); } return alloc_dims; } - void createAllocExpr(AllocationInformation& info, bool is_output) { + kir::Allocate* createAllocExpr(AllocationInformation& info, bool is_output) { if (is_output) { - info.alloc_expr = nullptr; - return; + return nullptr; } - std::vector alloc_dims; - const MemoryType memory_type = info.buffer->memoryType(); + std::vector alloc_dims; + const MemoryType memory_type = info.buffer->getMemoryType(); if (memory_type == MemoryType::Global) { alloc_dims = getGlobalAllocationSizes(info); @@ -453,60 +406,74 @@ class AllocationInserter : public kir::MutableIrVisitor { if (alloc_dims.size() == 0 && info.buffer->domain()->noReductions().size() != 0) { - alloc_dims.push_back(ir_builder.create(1)); + alloc_dims.push_back(info.buffer->container()->oneVal()); + } + + // Double the allocation size if double-buffered. Record the + // original size for indexing. + if (info.buffer->isDoubleBuffered()) { + Val* original_alloc_size = nullptr; + for (auto alloc_dim : alloc_dims) { + if (original_alloc_size == nullptr) { + original_alloc_size = alloc_dim; + } else { + original_alloc_size = + IrBuilder::mulExpr(original_alloc_size, alloc_dim); + } + } + GpuLower::current()->doubleBufferInfo().setOriginalAllocSize( + info.buffer, original_alloc_size); + alloc_dims.push_back(IrBuilder::create(2)); } // Create the allocation node - info.alloc_expr = ir_builder.create( - info.buffer, info.buffer->memoryType(), alloc_dims); + return IrBuilder::create( + info.buffer, info.buffer->getMemoryType(), alloc_dims); } - void handle(kir::Expr* expr) { - if (!ir_utils::isTVOp(expr) || expr->isA()) { - expr->accept(this); + void handle(Expr* expr) override { + if (!ir_utils::isTvOp(expr) || expr->isA()) { + ExprMutator::handle(expr); return; } // // Found where the allocation needs to be inserted for (auto out : expr->outputs()) { - if (!out->isA()) { + if (!out->isA()) { continue; } - auto out_tv = out->as(); - auto default_val = - gpu_lower->predicateElimination().getInitValue(out_tv->fuserTv()); + auto out_tv = out->as(); + auto default_val = gpu_lower->predicateElimination().getInitValue(out_tv); - kir::Val* init = nullptr; - if (expr->isA() && out_tv->fuserTv()->hasReduction()) { + Val* init = nullptr; + if (expr->isA() && out_tv->hasReduction()) { TORCH_INTERNAL_ASSERT( default_val == nullptr, "Reduction should not have a default initialization value for predicate elimination."); - init = expr->as()->init(); - } else if (expr->isA()) { + init = expr->as()->init(); + } else if (expr->isA()) { TORCH_INTERNAL_ASSERT( default_val == nullptr, "Welford should not have a default initialization value for predicate elimination."); - const auto welford = expr->as(); - if (out->id() == welford->outVar()->id()) { - init = welford->initVar() == nullptr - ? ir_builder.create(0) - : welford->initVar(); - } else if (out->id() == welford->outAvg()->id()) { - init = welford->initAvg() == nullptr - ? ir_builder.create(0) - : welford->initAvg(); + const auto welford = expr->as(); + if (out->name() == welford->outVar()->name()) { + init = welford->initVar() == nullptr ? IrBuilder::create(0) + : welford->initVar(); + } else if (out->name() == welford->outAvg()->name()) { + init = welford->initAvg() == nullptr ? IrBuilder::create(0) + : welford->initAvg(); } else { TORCH_INTERNAL_ASSERT( - out->id() == welford->outN()->id(), "Unreachable"); + out->name() == welford->outN()->name(), "Unreachable"); init = welford->initN(); } } else if (default_val != nullptr) { init = default_val; } - const bool is_output = gpu_lower->kernel()->isOutput(out); + const bool is_output = out->isFusionOutput(); // Don't need to alloc outputs, and if we don't need to initialize we're // done. @@ -516,150 +483,91 @@ class AllocationInserter : public kir::MutableIrVisitor { AllocationInformation allocation; allocation.buffer = out_tv; - findAllocationPosition(allocation, expr); - createAllocExpr(allocation, is_output); - createInitExpr(allocation, init); + fillAllocationInformation(allocation, expr); + + auto alloc_expr = createAllocExpr(allocation, is_output); + auto init_expr = createInitExpr(allocation, init); // Write information to GPULower - writeInfoToGPULower(allocation); + writeInfoToGPULower(allocation, alloc_expr); + + // Register allocations before initializations to keep them in the right + // order + if (alloc_expr != nullptr) { + if (allocation.buffer->getMemoryType() == MemoryType::Shared) { + // Shared allocations go at the begining of scope + TORCH_INTERNAL_ASSERT(!exprs_.empty()); + registerInsertBefore(exprs_[0], alloc_expr, nullptr); + } else { + TORCH_INTERNAL_ASSERT(allocation.alloc_place_before != nullptr); + kir::Scope* scope = allocation.alloc_for_loop == nullptr + ? nullptr + : &allocation.alloc_for_loop->body(); + registerInsertBefore( + allocation.alloc_place_before, alloc_expr, scope); + } + } - allocs.push_back(std::move(allocation)); + if (init_expr != nullptr) { + TORCH_INTERNAL_ASSERT(allocation.init_place_before != nullptr); + kir::Scope* scope = allocation.init_for_loop == nullptr + ? nullptr + : &allocation.init_for_loop->body(); + registerInsertBefore(allocation.init_place_before, init_expr, scope); + } } } - void writeInfoToGPULower(const AllocationInformation& allocation) { + // Sends alloc_expr, info.has_halo, info.allocation_domains to GpuLower + void writeInfoToGPULower( + const AllocationInformation& allocation, + kir::Allocate* alloc_expr) { auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap(); - if (allocation.alloc_expr == nullptr) { + if (alloc_expr == nullptr) { // Skip output allocation. return; } TORCH_INTERNAL_ASSERT( - !lower_alloc_info_map.count(allocation.alloc_expr), + !lower_alloc_info_map.count(alloc_expr), "duplicated allocation info entry"); // Create info entry for GPULower auto lower_alloc_info_ptr = std::make_unique(); - lower_alloc_info_ptr->alloc_expr = allocation.alloc_expr; + lower_alloc_info_ptr->alloc_expr = alloc_expr; lower_alloc_info_ptr->has_halo = allocation.has_halo; if (allocation.allocation_domains) { lower_alloc_info_ptr->alloc_domains = *(allocation.allocation_domains); } // Write entry to the stored map - lower_alloc_info_map[allocation.alloc_expr] = - std::move(lower_alloc_info_ptr); + lower_alloc_info_map[alloc_expr] = std::move(lower_alloc_info_ptr); } - void visit(kir::ForLoop* fl) final { - for_loops.push_back(fl); - // Modifying in place, make a copy of the vector - const std::vector exprs = fl->body().exprs(); - for (auto expr : exprs) { - handle(expr); - } - for_loops.pop_back(); - } - - void visit(kir::IfThenElse*) final { + void handle(kir::IfThenElse*) final { TORCH_INTERNAL_ASSERT( false, "Pass does not support conditional statements, ", "this pass should be run before any conditionals are placed in code."); } - AllocationInserter(std::vector _loop_nests) - : loop_nests_(std::move(_loop_nests)), - gpu_lower(GpuLower::current()), - ir_builder(gpu_lower->kernel()) { - // Compute all allocations - const std::vector exprs = loop_nests_; - for (auto expr : exprs) { - handle(expr); - } - - // First, place allocations of dynamic smem tensors at the very - // beginning of the expr list. Traverse backward as they should be - // placed in topological order. - for (auto it = allocs.rbegin(); it != allocs.rend(); ++it) { - const auto& alloc = *it; - if (alloc.alloc_expr == nullptr) { - continue; - } - // Dynamic smem exprs need to be at the begining of the kernel outside for - // loops - if (alloc.buffer->memoryType() == MemoryType::Shared && - !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) { - loop_nests_.insert(loop_nests_.begin(), alloc.alloc_expr); - } - } - - // Place the remaining allocations. - for (const auto& alloc : allocs) { - if (alloc.alloc_expr == nullptr) { - continue; - } - if (alloc.buffer->memoryType() == MemoryType::Shared && - !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) { - continue; - } - if (alloc.alloc_for_loop == nullptr) { - auto place_before_it = std::find( - loop_nests_.begin(), loop_nests_.end(), alloc.alloc_place_before); - TORCH_INTERNAL_ASSERT( - place_before_it != loop_nests_.end(), - "Could not figure out where to place allocation. ", - "Use of the buffer, ", - toString(alloc.buffer), - ", could not be found.", - toString(alloc.alloc_place_before)); - loop_nests_.insert(place_before_it, alloc.alloc_expr); - } else { - alloc.alloc_for_loop->body().insert_before( - alloc.alloc_place_before, alloc.alloc_expr); - } - } - - // Now that allocations are in place, place the initializations - for (const auto& alloc : allocs) { - if (alloc.init_expr == nullptr) { - continue; - } - if (alloc.init_for_loop == nullptr) { - auto place_before_it = std::find( - loop_nests_.begin(), loop_nests_.end(), alloc.init_place_before); - // Don't need a check here as if the allocation placement succeeded - // this will too - loop_nests_.insert(place_before_it, alloc.init_expr); - } else { - alloc.init_for_loop->body().insert_before( - alloc.init_place_before, alloc.init_expr); - } - } + AllocationInserter(const std::vector& exprs) + : gpu_lower(GpuLower::current()) { + kir::ExprMutator::traverseAndInsert(exprs); } private: - std::deque allocs; - - std::vector for_loops; - - std::vector loop_nests_; - GpuLower* gpu_lower; - kir::IrBuilder ir_builder; - public: - static std::vector insert( - const std::vector& loop_nests) { - AllocationInserter inserter(loop_nests); - return inserter.loop_nests_; + static std::vector insert(const std::vector& exprs) { + AllocationInserter inserter(exprs); + return inserter.exprs_; } }; } // namespace -std::vector insertAllocations( - const std::vector& exprs) { +std::vector insertAllocations(const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertAllocations"); return AllocationInserter::insert(exprs); } diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.h b/torch/csrc/jit/codegen/cuda/lower_allocation.h index bc0344ca19f..45ebeac03f7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.h +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.h @@ -1,8 +1,7 @@ #pragma once -#include +#include -#include #include #include @@ -17,7 +16,7 @@ namespace cuda { //! logic duplication struct LocalAllocationInfo { kir::Allocate* alloc_expr = nullptr; - std::vector alloc_domains; + std::vector alloc_domains; bool has_halo = false; }; @@ -25,7 +24,7 @@ using LocalAllocationInfoMap = std::unordered_map>; //! Insert buffer allocations -std::vector insertAllocations(const std::vector& exprs); +std::vector insertAllocations(const std::vector& exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp new file mode 100644 index 00000000000..c8110413de7 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -0,0 +1,508 @@ +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +unsigned int getDoubleBufferAxisPosition(const TensorView* tv) { + // Double-buffering prefetches the next subregion of the tensor by + // doubling the allocation. The subregion is defined by the axes + // at the CA position till the inner-most position. There must be + // at least one axis that is outside (left) of the CA position, + // which defines the loop where prefetching is applied. Therefore, + // the CA position must be larger than 0. + + TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0); + + // Unroll must not exist outside of double-buffer axis + auto first_unroll_it = std::find_if( + tv->domain()->domain().begin(), + tv->domain()->domain().end(), + [](const auto axis) { + return axis->getParallelType() == ParallelType::Unroll; + }); + + const int first_unroll_pos = + std::distance(tv->domain()->domain().begin(), first_unroll_it); + + const int unroll_or_ca_pos = + std::min((int)tv->getComputeAtPosition(), first_unroll_pos); + + TORCH_INTERNAL_ASSERT( + unroll_or_ca_pos > 0, + "Invalid tensor to double-buffer. Valid double buffer axis not found due to Unroll. ", + tv->toString()); + + int valid_pos = -1; + // Skip parallelized or broadcast axes + for (int i = unroll_or_ca_pos - 1; i >= 0; --i) { + auto pt = tv->axis(i)->getParallelType(); + if (!isParallelTypeThread(pt) && !tv->axis(i)->isBroadcast()) { + valid_pos = i; + break; + } + } + + TORCH_INTERNAL_ASSERT( + valid_pos >= 0, + "Invalid tensor to double-buffer. Valid double buffer axis not found. ", + tv->toString()); + + return valid_pos; +} + +IterDomain* getDoubleBufferAxis(const TensorView* tv) { + return tv->axis((int)getDoubleBufferAxisPosition(tv)); +} + +void validateDoubleBufferedTensor(const TensorView* tv) { + auto double_buffer_pos = getDoubleBufferAxisPosition(tv); + + // Like vectorization, only UnaryOp::Set with another TensorView is + // considered. + auto def = tv->definition(); + TORCH_INTERNAL_ASSERT( + def->isA() && + def->as()->getUnaryOpType() == UnaryOpType::Set, + "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set is supported: ", + def->toString()); + + TORCH_INTERNAL_ASSERT( + def->as()->in()->isA(), + "Invalid tensor to double-buffer. Only tensor defined by UnaryOp::Set with TensorView is supported: ", + def->toString()); + + // Require the producer tensor to have been computed entirely for + // the double-buffering loop. Otherwise, the producer itself would + // also need to be double-bufferred. + auto producer = def->as()->in()->as(); + TORCH_INTERNAL_ASSERT( + producer->getComputeAtPosition() <= double_buffer_pos, + "Invalid tensor to double-buffer. The computeAt position of the producer tensor must be moved left: ", + producer->toString()); + + // Not strictly necessary, but only gmem -> smem or local and smem -> local + // are allowed. + const auto p_mem_type = producer->getMemoryType(); + const auto c_mem_type = tv->getMemoryType(); + TORCH_INTERNAL_ASSERT( + (p_mem_type == MemoryType::Global && + (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) || + (p_mem_type == MemoryType::Shared && c_mem_type == MemoryType::Local), + "Invalid tensor to double-buffer: ", + tv->toString(), + ". Producer memory type: ", + p_mem_type, + ". Consumer memory type: ", + c_mem_type); + + return; +} + +namespace { + +// Initial inspection of a fusion to find and validate double buffered tensors +class DoubleBufferFusionInspector : private IterVisitor { + public: + DoubleBufferFusionInspector(Fusion* fusion, DoubleBufferInfo& db_info) + : db_info_(db_info) { + traverse(fusion); + } + + private: + using IterVisitor::handle; + + void handle(TensorView* tv) final { + if (!tv->isDoubleBuffered()) { + return; + } + + validateDoubleBufferedTensor(tv); + + auto db_axis = getDoubleBufferAxis(tv); + + db_info_.setDoubleBufferAxis(tv, db_axis); + } + + private: + DoubleBufferInfo& db_info_; +}; + +// The type of replicated double-buffer loops +enum class LoopType { Prologue, Main, Epilogue }; + +// The epilogue loop is only created when the producer of a double +// buffer tensor is on smem, in which case it would otherwise require +// an additional predicate to guard buffer overruns. When it's on +// gmem, that isn't the case, so it does not need to create an +// epilogue loop. +bool requireEpilogue(const std::vector& exprs) { + return std::any_of(exprs.begin(), exprs.end(), [](const UnaryOp* uop) { + return uop->in()->as()->getMemoryType() == MemoryType::Shared; + }); +} + +// Replicates double buffer loops for Prologue, Main, and +// Epilogue. Prologue only copies the load expressions of double +// buffered tensors, whereas Epilogue does any expression other than +// the loads. Main copies everything. +class DoubleBufferLoopCloner : public kir::IrVisitor { + public: + static kir::ForLoop* clone( + kir::ForLoop* double_buffer_loop, + const std::vector& double_buffer_load_exprs, + LoopType loop_type) { + DoubleBufferLoopCloner cloner( + double_buffer_loop, double_buffer_load_exprs, loop_type); + cloner.clone(); + return cloner.cloned_top_level_loop_; + } + + private: + DoubleBufferLoopCloner( + kir::ForLoop* double_buffer_loop, + const std::vector& double_buffer_load_exprs, + LoopType loop_type) + : double_buffer_loop_(double_buffer_loop), + double_buffer_load_exprs_(double_buffer_load_exprs), + loop_type_(loop_type) {} + + using kir::IrVisitor::handle; + + void clone() { + const auto gpu_lower = GpuLower::current(); + + // Cloning the double buffer loop as follows: + // + // Prologue: 0 to 1 + // Main: 0 to (extent-1) + // Epilogue: (extent-1) to extent + + auto index = IrBuilder::create(c10::nullopt); + auto start = double_buffer_loop_->start(); + auto stop = double_buffer_loop_->stop(); + + if (loop_type_ == LoopType::Prologue) { + TORCH_INTERNAL_ASSERT(start->isZeroInt()); + stop = gpu_lower->kernel()->oneVal(); + } else if ( + loop_type_ == LoopType::Main && + requireEpilogue(double_buffer_load_exprs_)) { + stop = IrBuilder::subExpr( + double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); + } else if (loop_type_ == LoopType::Epilogue) { + TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_)); + start = IrBuilder::subExpr( + double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); + } + + cloned_top_level_loop_ = IrBuilder::create( + double_buffer_loop_->iter_domain(), + index, + start, + stop, + gpu_lower->kernel()->oneVal(), + false, + nullptr, + double_buffer_loop_->isUnrollRequired()); + + handle(double_buffer_loop_); + } + + void handle(kir::ForLoop* fl) final { + const auto gpu_lower = GpuLower::current(); + + kir::ForLoop* cloned_loop = fl == double_buffer_loop_ + ? cloned_top_level_loop_ + : IrBuilder::create(fl); + + cloned_scopes_.push_back(&cloned_loop->body()); + + kir::IrVisitor::handle(fl); + + cloned_scopes_.pop_back(); + + // Add the cloned loop into the parent loop body only when the + // cloned loop contains expressions. + if (!cloned_loop->body().empty() && !cloned_scopes_.empty()) { + cloned_scopes_.back()->push_back(cloned_loop); + } + } + + void handle(kir::IfThenElse* ite) final { + TORCH_INTERNAL_ASSERT(false, "No IfThenElse should exist yet"); + } + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + TORCH_INTERNAL_ASSERT(!cloned_scopes_.empty()); + + if (loop_type_ == LoopType::Main) { + cloned_scopes_.back()->push_back(expr); + return; + } + + // In Prologue and Epilogue, either load expressions or anything + // else are copied. Note that there can be multiple exprs defining + // double buffered TVs (e.g., buffer initialization). + + auto out_tv = ir_utils::getTvOutput(expr); + const auto is_double_buffer_load_expr = std::any_of( + double_buffer_load_exprs_.begin(), + double_buffer_load_exprs_.end(), + [out_tv](const auto load_expr) { + auto double_buffer_tv = ir_utils::getTvOutput(load_expr); + TORCH_INTERNAL_ASSERT(double_buffer_tv != nullptr); + return out_tv == double_buffer_tv; + }); + if ((loop_type_ == LoopType::Prologue && is_double_buffer_load_expr) || + (loop_type_ == LoopType::Epilogue && !is_double_buffer_load_expr)) { + cloned_scopes_.back()->push_back(expr); + } + } + + private: + kir::ForLoop* double_buffer_loop_ = nullptr; + const std::vector& double_buffer_load_exprs_; + const LoopType loop_type_; + + kir::ForLoop* cloned_top_level_loop_ = nullptr; + std::deque cloned_scopes_; +}; + +using InsertionInfo = std::unordered_map>; + +// Traverse lowered loop-nests and find all double buffer loops and +// associated load expressions. +class DoubleBufferLoopNestInspector : private kir::IrVisitor { + public: + static InsertionInfo run(const std::vector& exprs) { + DoubleBufferLoopNestInspector inspector(exprs); + return inspector.insertion_info_; + } + + private: + DoubleBufferLoopNestInspector(const std::vector& exprs) { + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(UnaryOp* uop) final { + const auto gpu_lower = GpuLower::current(); + + auto out_tv = ir_utils::getTvOutput(uop); + + if (out_tv == nullptr) { + return; + } + + // Ignore init loop + if (!out_tv->isDoubleBuffered() || !uop->in()->isA()) { + return; + } + + auto double_buffer_loop = + gpu_lower->doubleBufferInfo().getDoubleBufferLoop(out_tv, for_loops_); + + TORCH_INTERNAL_ASSERT( + double_buffer_loop != nullptr, + "No double buffer loop found for a double buffered tensor: ", + out_tv->toString()); + + validateDoubleBufferLoop(double_buffer_loop); + + insertion_info_[double_buffer_loop].push_back(uop); + } + + static void validateDoubleBufferLoop(kir::ForLoop* loop) { + TORCH_INTERNAL_ASSERT( + loop->start()->isZeroInt(), "Unsupported loop: ", loop->toString()); + TORCH_INTERNAL_ASSERT( + loop->step()->isOneInt(), "Unsupported loop: ", loop->toString()); + TORCH_INTERNAL_ASSERT( + !loop->vectorize(), + "Vectorized loop should not be the allocation loop for double-buffered tensor: ", + loop->toString()); + TORCH_INTERNAL_ASSERT( + !loop->vectorize_shift(), + "Vectorize shift loop should not be the allocation loop for double-buffered tensor: ", + loop->toString()); + } + + InsertionInfo insertion_info_; +}; + +// Apply double buffering transformations +class DoubleBufferInserter : private kir::ExprMutator { + public: + // When there exist multiple double buffer loops, apply + // transformations to inner-most loops first. A single ExprMutator + // pass can only process one loop. + static std::vector run( + const std::vector& exprs, + InsertionInfo insertion_info) { + auto inserted_exprs = exprs; + while (!insertion_info.empty()) { + DoubleBufferInserter inserter(inserted_exprs, insertion_info); + inserted_exprs = inserter.exprs_; + } + return inserted_exprs; + } + + private: + DoubleBufferInserter( + const std::vector& exprs, + InsertionInfo& insertion_info) + : insertion_info_(insertion_info) { + auto num_double_buffer_loops = insertion_info.size(); + traverseAndInsert(exprs); + TORCH_INTERNAL_ASSERT(processed_loop_ != nullptr); + TORCH_INTERNAL_ASSERT(insertion_info.size() == num_double_buffer_loops - 1); + } + + using kir::ExprMutator::handle; + + void handle(kir::ForLoop* loop) final { + kir::ExprMutator::handle(loop); + + // If another loop is already taken care of, no more loop should + // be done in the same pass + if (processed_loop_ != nullptr) { + return; + } + + auto it = insertion_info_.find(loop); + if (it == insertion_info_.end()) { + return; + } + + insert(loop, it->second); + processed_loop_ = loop; + insertion_info_.erase(loop); + } + + void insert( + kir::ForLoop* double_buffer_loop, + const std::vector& loads) { + auto prologue_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, LoopType::Prologue); + registerInsertBefore(double_buffer_loop, prologue_loop); + + auto write_to_smem = + std::any_of(loads.begin(), loads.end(), [](const UnaryOp* uop) { + return uop->out()->as()->getMemoryType() == + MemoryType::Shared; + }); + + // RAW sync is not inserted for double buffered tensors. The only + // exception is the prologue load. + if (write_to_smem) { + auto sync = IrBuilder::create(); + registerInsertBefore(double_buffer_loop, sync); + } + + auto main_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, LoopType::Main); + registerReplace(double_buffer_loop, main_loop); + + if (requireEpilogue(loads)) { + auto epilogue_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, LoopType::Epilogue); + registerInsertAfter(double_buffer_loop, epilogue_loop); + } + } + + private: + InsertionInfo& insertion_info_; + kir::ForLoop* processed_loop_ = nullptr; +}; + +} // namespace + +void DoubleBufferInfo::build(Fusion* fusion) { + DoubleBufferFusionInspector inspector(fusion, *this); +} + +DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) { + TORCH_INTERNAL_ASSERT( + tv->isDoubleBuffered(), "Not a double-buffered tensor: ", tv->toString()); + return map_[tv]; +} + +void DoubleBufferInfo::setDoubleBufferAxis( + const TensorView* tv, + IterDomain* axis) { + getTvInfo(tv).double_buffer_axis = axis; +} + +IterDomain* DoubleBufferInfo::getDoubleBufferAxis(const TensorView* tv) { + if (!tv->isDoubleBuffered()) { + return nullptr; + } + + return getTvInfo(tv).double_buffer_axis; +} + +kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop( + IterDomain* axis, + const std::vector& loops, + bool ignore_prologue) { + auto loop_it = std::find_if(loops.begin(), loops.end(), [&](const auto loop) { + return GpuLower::current()->caIndexMap().areMapped( + loop->iter_domain(), axis) && + (!ignore_prologue || !loop->stop()->isOneInt()); + }); + + if (loop_it != loops.end()) { + return *loop_it; + } else { + return nullptr; + } +} + +kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop( + const TensorView* tv, + const std::vector& loops, + bool ignore_prologue) { + auto axis = getDoubleBufferAxis(tv); + + if (axis == nullptr) { + return nullptr; + } + + return getDoubleBufferLoop(axis, loops, ignore_prologue); +} + +void DoubleBufferInfo::setOriginalAllocSize( + const TensorView* tv, + Val* original_alloc_size) { + getTvInfo(tv).original_alloc_size = original_alloc_size; +} + +Val* DoubleBufferInfo::getOriginalAllocSize(const TensorView* tv) { + if (!tv->isDoubleBuffered()) { + return nullptr; + } + + return getTvInfo(tv).original_alloc_size; +} + +std::vector DoubleBufferPass::run(const std::vector& exprs) { + auto insertion_info = DoubleBufferLoopNestInspector::run(exprs); + return DoubleBufferInserter::run(exprs, insertion_info); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h new file mode 100644 index 00000000000..96bc247f4ff --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -0,0 +1,142 @@ +#pragma once + +#include + +#include +#include +#include + +// Double buffering a tensor doubles its allocation size and uses two +// buffers to facilitate computation and memory access +// overlapping. The basic form of code looks like as follows: +// +// Before: +// for i +// x[S]; // allocation +// for j: +// x[j] = y[i, j] +// for j: +// ... = x[j] +// +// After: +// X[S * 2]; // allocation +// for i in 0 to 1: // Prologue +// for j: +// x[j] = y[i, j] +// +// for i in 0 to N-1: // Main +// for j: +// x[j + (1 - i % 2) * S] = y[i + 1, j] +// for j: +// ... = x[j + (i % 2) * S] +// +// for i in N-1 to N: // Epilogue +// for j: +// ... = x[j + (i % 2) * S] +// +// Here, S is the original size of tensor x. +// +// The i loop is the double buffer loop of tensor x, where double +// buffering is applied to the tensor. The first step of lowering is +// to find the double buffering axis for each double buffered +// tensor. It must not be parallelized as it isn't possible to double +// buffer parallelized loops. Also, an unrolled axis expands the +// allocation and is intended to make the loop completely unrolled, +// which also conflicts with double buffering. So, basically, the double +// buffering axis is the inner-most axis within the axes left +// of the CA position. However, when it is parallelized or unrolled, a +// further left axis is picked. +// +// Once the double buffer axis is determined, the main task is to +// replicate the corresponding double buffer loop as illustrated +// above. The Prologue loop is to just fetch the first element to +// populate the buffer. The main loop is mostly the same as the +// original loop, except for the indexing change to switch the two +// buffers. When used as a consumer, an offset of (1 - i % 2) * S is +// added, whereas (i % 2) * S is added when used as a producer. Here, +// i is the index of the double buffer loop. The Epilogue loop is just +// for the last iteration of the loop. Since the main loop reads one +// element ahead of the producer of the double buffered tensor, it +// would require an additional guard to prevent buffer overruns with +// the producer if the main loop were also used for the last +// iteration. However, the value loaded by the invalid load would not +// be used, so instead of adding the additional predicate, the Epilogue +// loop is replicated from the original loop, except for the load +// expression since it's not used. Note that this overrun does not +// happen when the producer is on gmem, so in that case, this +// additional replication is not done. +// +// When creating those three types of loops, additional care must be +// taken when multiple tensors are double buffered. When multiple +// tensors use the same loop as their double buffer loop, one pass of +// replication takes care of them at once, meaning the same Prologue, +// Main, Epilogue loops are used for the multiple tensors. +// +// Other tasks to do for a double buffer tensor include: +// - Move allocation to outside of the double buffer loop +// - Double the allocation size +// - Omit the RAW sync in the Main and Epilogue loops + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +unsigned int getDoubleBufferAxisPosition(const TensorView* tv); + +IterDomain* getDoubleBufferAxis(const TensorView* tv); + +void validateDoubleBufferedTensor(const TensorView* tv); + +class TORCH_CUDA_CU_API DoubleBufferPass { + public: + //! Apply double buffering transformations + static std::vector run(const std::vector& exprs); +}; + +class TORCH_CUDA_CU_API DoubleBufferInfo { + // Lowering information of double buffered tensors. + struct TvInfo { + IterDomain* double_buffer_axis = nullptr; + Val* original_alloc_size = nullptr; + }; + + public: + void build(Fusion* fusion); + + void setDoubleBufferAxis(const TensorView* tv, IterDomain* id); + + IterDomain* getDoubleBufferAxis(const TensorView* tv); + + //! Get a loop that matches with a given double-buffer axis. If + //! ignore_prologue is true, a matched loop is ignored if it's a + //! prologue loop. + static kir::ForLoop* getDoubleBufferLoop( + IterDomain* axis, + const std::vector& loops, + bool ignore_prologue = false); + + //! Get a loop that matches with the double-buffer axis of a given + //! double-buffered tensor. If ignore_prologue is true, a matched + //! loop is ignored if it's a prologue loop. + kir::ForLoop* getDoubleBufferLoop( + const TensorView* tv, + const std::vector& loops, + bool ignore_prologue = false); + + void setOriginalAllocSize(const TensorView* tv, Val* size); + + Val* getOriginalAllocSize(const TensorView* tv); + + private: + TvInfo& getTvInfo(const TensorView* tv); + + private: + //! Keeps track of information for lowering double buffered tensors + std::unordered_map map_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp index 2353ea9bbf5..84c72c08185 100644 --- a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp @@ -541,7 +541,7 @@ ExprGroup* ExprSegmentationSorter::makeEmptyGroup() { ExprGroup* ExprSegmentationSorter::makeEmptyGroup(Expr* expr) { auto group = makeEmptyGroup(); group->exprs().push_back(expr); - if (ir_utils::isTVOp(expr)) { + if (ir_utils::isTvOp(expr)) { auto out_tv = expr->outputs()[0]->as(); // Grab all id's that are shared with other tensors. for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) { @@ -721,7 +721,7 @@ std::vector getLocalDomainOrdering( std::unordered_set domains; for (auto expr : exprs) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { continue; } diff --git a/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp new file mode 100644 index 00000000000..fa84d1006a1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp @@ -0,0 +1,119 @@ +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +// Replace trivial reductions with unary ops. +class TrivialReductionReplacement : private OptOutMutator { + public: + TrivialReductionReplacement( + Fusion* fusion, + const TrivialReductionInfo& trivial_reduction_info) + : trivial_reduction_info_(trivial_reduction_info) { + FusionGuard fg(fusion); + auto exprs = StmtSort::getExprs(fusion); + for (auto expr : exprs) { + mutate(expr); + } + } + + private: + using OptOutMutator::mutate; + void mutate(ReductionOp* rop) final { + if (rop->out()->isA()) { + auto out_tv = rop->out()->as(); + if (std::all_of( + out_tv->domain()->domain().begin(), + out_tv->domain()->domain().end(), + [&](IterDomain* id) { + // If id is a reduction axis, is it a trivial reduction? + if (id->isReduction()) { + return trivial_reduction_info_.isDerived(id); + } else { + return true; + } + })) { + auto out = rop->out(); + auto in = rop->in(); + auto container = out->container(); + removeExpr(container, rop); + IrBuilder::create(container, UnaryOpType::Set, out, in); + } + } + } + + const TrivialReductionInfo& trivial_reduction_info_; +}; + +// Replaces Transpose, Shift, Gather, and View Ops with Unary Ops. +class UnaryOpInserter : private kir::ExprMutator { + public: + static std::vector insert(const std::vector& exprs) { + UnaryOpInserter inserter(exprs); + return inserter.exprs_; + } + + private: + using kir::ExprMutator::handle; + + UnaryOpInserter(const std::vector& exprs) { + kir::ExprMutator::traverseAndInsert(exprs); + } + + void handle(TransposeOp* top) final { + auto out = top->out(); + auto in = top->in(); + auto container = out->container(); + registerReplace( + top, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } + + void handle(ShiftOp* sop) final { + auto out = sop->out(); + auto in = sop->in(); + auto container = out->container(); + registerReplace( + sop, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } + + void handle(GatherOp* gop) final { + auto out = gop->out(); + auto in = gop->in(); + auto container = out->container(); + registerReplace( + gop, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } + + void handle(ViewOp* vop) final { + auto out = vop->out(); + auto in = vop->in(); + auto container = out->container(); + registerReplace( + vop, IrBuilder::create(container, UnaryOpType::Set, out, in)); + } +}; + +} // namespace + +void trivialReductionReplacement( + Fusion* fusion, + const TrivialReductionInfo& trivial_reduction_info) { + TrivialReductionReplacement replacement(fusion, trivial_reduction_info); +} + +// Transpose, Shift, Gather, and View Ops with Unary Set Ops +std::vector unarySetOpInserter(const std::vector& exprs) { + return UnaryOpInserter::insert(exprs); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h new file mode 100644 index 00000000000..e18f4a8f077 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Replaces trivial reductions with Unary Set Ops +void trivialReductionReplacement(Fusion*, const TrivialReductionInfo&); + +// Transpose, Shift, Gather, and View Ops with Unary Set Ops +std::vector unarySetOpInserter(const std::vector& exprs); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index d92dd279b17..b0ef14079c4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include @@ -13,30 +12,24 @@ namespace jit { namespace fuser { namespace cuda { -IndexLowering::IndexLowering() : ir_builder_(GpuLower::current()->kernel()) {} - -kir::Val* IndexLowering::lowerSrcIndex(kir::Val* src, kir::Val* dst) const { - if (auto tv = dynamic_cast(src)) { - TORCH_INTERNAL_ASSERT(dst->isA()); - return Index::getProducerIndex( - tv->fuserTv(), - dst->as()->fuserTv(), - scope_utils::getLoops(active_scope_expr_)); +Val* IndexLowering::lowerSrcIndex(Val* src, Val* dst) const { + if (auto tv = dynamic_cast(src)) { + TORCH_INTERNAL_ASSERT(dst->isA()); + return Index::getProducerIndex(tv, dst->as(), for_loops_); } else { return src; } } -kir::Val* IndexLowering::lowerDstIndex(kir::Val* dst) const { - if (auto tv = dynamic_cast(dst)) { - return Index::getConsumerIndex( - tv->fuserTv(), scope_utils::getLoops(active_scope_expr_)); +Val* IndexLowering::lowerDstIndex(Val* dst) const { + if (auto tv = dynamic_cast(dst)) { + return Index::getConsumerIndex(tv, for_loops_); } else { return dst; } } -void IndexLowering::pushBack(kir::Expr* expr) { +void IndexLowering::pushBack(Expr* expr) { if (active_scope_ == nullptr) { lowered_exprs_.push_back(expr); } else { @@ -44,78 +37,71 @@ void IndexLowering::pushBack(kir::Expr* expr) { } } -void IndexLowering::visit(const kir::IfThenElse* ite) { - const auto prev_scope_expr = active_scope_expr_; +void IndexLowering::handle(const kir::IfThenElse* ite) { const auto prev_scope = active_scope_; - // TODO(kir): try to avoid recreating new nodes and leaving old ones around - auto new_ite = ir_builder_.create(ite->predicate()); + auto new_ite = IrBuilder::create(ite->predicate()); pushBack(new_ite); - active_scope_expr_ = new_ite; active_scope_ = &new_ite->thenBody(); for (auto expr : ite->thenBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } active_scope_ = &new_ite->elseBody(); for (auto expr : ite->elseBody().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } active_scope_ = prev_scope; - active_scope_expr_ = prev_scope_expr; } -void IndexLowering::visit(const kir::ForLoop* for_loop) { - const auto prev_scope_expr = active_scope_expr_; +void IndexLowering::handle(const kir::ForLoop* for_loop) { const auto prev_scope = active_scope_; - auto new_for_loop = ir_builder_.create(for_loop); + auto new_for_loop = IrBuilder::create(for_loop); pushBack(new_for_loop); - active_scope_expr_ = new_for_loop; active_scope_ = &new_for_loop->body(); + for_loops_.push_back(new_for_loop); for (auto expr : for_loop->body().exprs()) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } + for_loops_.pop_back(); active_scope_ = prev_scope; - active_scope_expr_ = prev_scope_expr; } -void IndexLowering::visit(const kir::UnaryOp* uop) { +void IndexLowering::handle(const UnaryOp* uop) { const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); - pushBack(ir_builder_.create(uop->operation(), out, in)); + pushBack(IrBuilder::create(uop->getUnaryOpType(), out, in)); } -void IndexLowering::visit(const kir::BinaryOp* bop) { +void IndexLowering::handle(const BinaryOp* bop) { const auto lhs = lowerSrcIndex(bop->lhs(), bop->out()); const auto rhs = lowerSrcIndex(bop->rhs(), bop->out()); const auto out = lowerDstIndex(bop->out()); - pushBack(ir_builder_.create(bop->operation(), out, lhs, rhs)); + pushBack(IrBuilder::create(bop->getBinaryOpType(), out, lhs, rhs)); } -void IndexLowering::visit(const kir::TernaryOp* top) { +void IndexLowering::handle(const TernaryOp* top) { const auto in1 = lowerSrcIndex(top->in1(), top->out()); const auto in2 = lowerSrcIndex(top->in2(), top->out()); const auto in3 = lowerSrcIndex(top->in3(), top->out()); const auto out = lowerDstIndex(top->out()); - pushBack( - ir_builder_.create(top->operation(), out, in1, in2, in3)); + pushBack(IrBuilder::create( + top->getTernaryOpType(), out, in1, in2, in3)); } namespace { // Get the size of the temporary work buffer for grid communication, this can be // grid reduction, broadcast, or grid welford. -kir::Val* getGridCommWorkBufferSize( - kir::IrBuilder& ir_builder, - const kir::TensorDomain* td) { +Val* getGridCommWorkBufferSize(const TensorDomain* td) { // The buffer size is the number of thread blocks multiplied by the // number of threads not used for reduction domains. // Note: Previously it was calculated based on the shape of the @@ -125,7 +111,7 @@ kir::Val* getGridCommWorkBufferSize( // size if the parallel dimensions are exact, but otherwise, just // computing the buffer size based on the tensor shape isn't // sufficient since there could be extra threads/blocks. - kir::Val* buffer_size = ir_builder.create(1); + Val* buffer_size = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeThreads) { auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); if (pt_dim == nullptr || pt_dim->isOneInt()) { @@ -133,33 +119,31 @@ kir::Val* getGridCommWorkBufferSize( } if (isParallelTypeThreadDim(pt) && std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) { - return out_id->parallelType() == pt && + return out_id->getParallelType() == pt && (out_id->isReduction() || out_id->isBroadcast()); })) { continue; } - buffer_size = ir_builder.mulExpr(buffer_size, pt_dim); + buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim); } return buffer_size; } -kir::Val* getGridSyncBufferSize( - kir::IrBuilder& ir_builder, - const kir::TensorDomain* td) { +Val* getGridSyncBufferSize(const TensorDomain* td) { // See the comment above for getGridCommWorkBufferSize. - kir::Val* buffer_size = ir_builder.create(1); + Val* buffer_size = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeBIDs) { auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); if (pt_dim == nullptr || pt_dim->isOneInt()) { continue; } if (std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) { - return out_id->parallelType() == pt && + return out_id->getParallelType() == pt && (out_id->isReduction() || out_id->isBroadcast()); })) { continue; } - buffer_size = ir_builder.mulExpr(buffer_size, pt_dim); + buffer_size = IrBuilder::mulExpr(buffer_size, pt_dim); } return buffer_size; } @@ -167,26 +151,25 @@ kir::Val* getGridSyncBufferSize( // Allocate global buffer for a grid communication calls, i.e. grid reduce, grid // welford reduce, grid broadcast. kir::Allocate* allocGlobalBufferForGridComm( - kir::IrBuilder& ir_builder, - kir::Val* buffer_size, + Val* buffer_size, DataType dtype, bool zero_init) { - const std::vector new_buffer_ids = { - ir_builder.create(ir_builder.zeroVal(), buffer_size)}; - const auto buffer_domain = - ir_builder.create(new_buffer_ids); - const auto buffer_tv = ir_builder.create( - dtype, buffer_domain, MemoryType::Global); - return ir_builder.create( - buffer_tv, buffer_tv->memoryType(), nullptr, zero_init); + const std::vector new_buffer_ids = { + IrBuilder::create( + GpuLower::current()->kernel()->zeroVal(), buffer_size)}; + const auto buffer_domain = IrBuilder::create(new_buffer_ids); + const auto buffer_tv = + IrBuilder::create(buffer_domain, dtype, MemoryType::Global); + return IrBuilder::create( + buffer_tv, buffer_tv->getMemoryType(), nullptr, zero_init); } } // namespace -void IndexLowering::visit(const kir::ReductionOp* rop) { - TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(rop)); +void IndexLowering::handle(const ReductionOp* rop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(rop)); - const auto out_tv = rop->out()->as(); + const auto out_tv = rop->out()->as(); const auto out_domain = out_tv->domain(); const bool is_block_reduce = out_domain->hasBlockReduction(); @@ -199,7 +182,7 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { std::none_of( out_domain->domain().begin(), out_domain->domain().end(), - [](kir::IterDomain* id) { + [](IterDomain* id) { return !id->isThread() && id->isReduction() && !id->extent()->isOneInt(); }), @@ -212,11 +195,11 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { const auto out = lowerDstIndex(rop->out()); const auto in = lowerSrcIndex(rop->in(), rop->out()); - kir::ReductionOp* block_reduction_op = nullptr; + ReductionOp* block_reduction_op = nullptr; if (is_block_reduce) { - block_reduction_op = ir_builder_.create( - rop->operation(), rop->init(), out, in); + block_reduction_op = IrBuilder::create( + rop->getReductionOpType(), rop->init(), out, in); if (rop->predicate()) { block_reduction_op->setPredicate(rop->predicate()); } @@ -228,29 +211,22 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { if (is_grid_reduce) { const auto reduce_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridCommWorkBufferSize(ir_builder_, out_domain), - out->dtype(), - false); + getGridCommWorkBufferSize(out_domain), out->dtype(), false); const auto sync_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridSyncBufferSize(ir_builder_, out_domain), - DataType::Int, - true); + getGridSyncBufferSize(out_domain), DataType::Int, true); const auto grid_reduction_op = (block_reduction_op == nullptr) - ? ir_builder_.create( - rop->operation(), rop->init(), out, in) + ? IrBuilder::create( + rop->getReductionOpType(), rop->init(), out, in) : block_reduction_op; // The thread predicate for GridReduction needs to be set // separately from the main predicate. Do not combine them like // other expressions. const auto& thread_pred = - GpuLower::current()->threadPredMap().getPredicatedParallelTypes( - out_tv->fuserTv()); - auto grid_reduction = ir_builder_.create( + GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); + auto grid_reduction = IrBuilder::create( grid_reduction_op, reduce_buffer, sync_buffer); grid_reduction->setThreadPredicate(thread_pred); @@ -260,8 +236,8 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { // predicate does not work when the write predicate of the // blockReduce is different from the read predicate. if (is_block_reduce) { - grid_reduction->setPredicate( - ir_builder_.create(ir_builder_.trueVal())); + grid_reduction->setPredicate(IrBuilder::create( + GpuLower::current()->kernel()->trueVal())); } else { grid_reduction->setPredicate(rop->predicate()); } @@ -277,15 +253,15 @@ void IndexLowering::visit(const kir::ReductionOp* rop) { } if (!is_block_reduce && !is_grid_reduce) { - // TODO(kir): this breaks our "SSA" form - pushBack(ir_builder_.create(rop->operation(), out, out, in)); + pushBack( + IrBuilder::create(rop->getReductionOpType(), out, out, in)); } } -void IndexLowering::visit(const kir::WelfordOp* wop) { - TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(wop)); +void IndexLowering::handle(const WelfordOp* wop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(wop)); - const auto out_tv = wop->outAvg()->as(); + const auto out_tv = wop->outAvg()->as(); const auto out_domain = out_tv->domain(); const bool is_block_reduce = out_domain->hasBlockReduction(); @@ -298,7 +274,7 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { std::none_of( out_domain->domain().begin(), out_domain->domain().end(), - [](kir::IterDomain* id) { + [](IterDomain* id) { return !id->isThread() && id->isReduction(); }), "Found a reduction stage that has both a non-parallelized ", @@ -322,18 +298,18 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { auto out_var = lowerDstIndex(wop->outVar()); auto out_N = lowerDstIndex(wop->outN()); - kir::WelfordOp* welford_op = ir_builder_.create( - out_var, + WelfordOp* welford_op = IrBuilder::create( out_avg, + out_var, out_N, - wop->initVar(), wop->initAvg(), + wop->initVar(), wop->initN(), - in_var, in_avg, + in_var, in_N); - kir::WelfordOp* block_welford_op = nullptr; + WelfordOp* block_welford_op = nullptr; if (is_block_reduce) { block_welford_op = welford_op; @@ -348,21 +324,17 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { if (is_grid_reduce) { // Buffer allocation - const auto work_buffer_size = - getGridCommWorkBufferSize(ir_builder_, out_domain); + const auto work_buffer_size = getGridCommWorkBufferSize(out_domain); - const auto out_var_buffer = allocGlobalBufferForGridComm( - ir_builder_, work_buffer_size, out_var->dtype(), false); - const auto out_avg_buffer = allocGlobalBufferForGridComm( - ir_builder_, work_buffer_size, out_avg->dtype(), false); - const auto out_N_buffer = allocGlobalBufferForGridComm( - ir_builder_, work_buffer_size, out_N->dtype(), false); + const auto out_var_buffer = + allocGlobalBufferForGridComm(work_buffer_size, out_var->dtype(), false); + const auto out_avg_buffer = + allocGlobalBufferForGridComm(work_buffer_size, out_avg->dtype(), false); + const auto out_N_buffer = + allocGlobalBufferForGridComm(work_buffer_size, out_N->dtype(), false); const auto sync_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridSyncBufferSize(ir_builder_, out_domain), - DataType::Int, - true); + getGridSyncBufferSize(out_domain), DataType::Int, true); // Grid Welford instantiation const auto grid_welford_op = @@ -372,10 +344,9 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { // separately from the main predicate. Do not combine them like // other expressions. const auto& thread_pred = - GpuLower::current()->threadPredMap().getPredicatedParallelTypes( - out_tv->fuserTv()); + GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); - auto grid_welford = ir_builder_.create( + auto grid_welford = IrBuilder::create( grid_welford_op, out_var_buffer, out_avg_buffer, @@ -400,18 +371,18 @@ void IndexLowering::visit(const kir::WelfordOp* wop) { } } -void IndexLowering::visit(const kir::BroadcastOp* bop) { - TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(bop)); +void IndexLowering::handle(const BroadcastOp* bop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(bop)); - const auto out_tv = bop->out()->as(); + const auto out_tv = bop->out()->as(); const auto out = lowerDstIndex(bop->out()); const auto in = lowerSrcIndex(bop->in(), bop->out()); - auto indexed_expr = ir_builder_.create(out, in); + auto indexed_expr = + IrBuilder::create(out, in, bop->getBroadcastDimFlags()); const ParallelTypeBitmap parallel_bitmap = - GpuLower::current()->threadPredMap().getParallelBroadcastDomains( - out_tv->fuserTv()); + GpuLower::current()->threadPredMap().getParallelBroadcastDomains(out_tv); const bool block_x = parallel_bitmap.get(ParallelType::BIDx); const bool block_y = parallel_bitmap.get(ParallelType::BIDy); @@ -430,18 +401,12 @@ void IndexLowering::visit(const kir::BroadcastOp* bop) { // Grid broadcast const auto out_domain = out_tv->domain(); const auto broadcast_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridCommWorkBufferSize(ir_builder_, out_domain), - out->dtype(), - false); + getGridCommWorkBufferSize(out_domain), out->dtype(), false); const auto sync_buffer = allocGlobalBufferForGridComm( - ir_builder_, - getGridSyncBufferSize(ir_builder_, out_domain), - DataType::Int, - true); + getGridSyncBufferSize(out_domain), DataType::Int, true); - auto grid_broadcast = ir_builder_.create( + auto grid_broadcast = IrBuilder::create( indexed_expr, broadcast_buffer, sync_buffer); if (bop->predicate()) { @@ -453,19 +418,19 @@ void IndexLowering::visit(const kir::BroadcastOp* bop) { pushBack(grid_broadcast); } -void IndexLowering::visit(const kir::Allocate* allocate) { +void IndexLowering::handle(const kir::Allocate* allocate) { // TODO(kir): remove the need for const_cast pushBack(const_cast(allocate)); // NOLINT } -void IndexLowering::visit(const kir::Sync* sync) { +void IndexLowering::handle(const kir::Sync* sync) { // TODO(kir): remove the need for const_cast pushBack(const_cast(sync)); // NOLINT } -void IndexLowering::generate(const std::vector& exprs) { +void IndexLowering::generate(const std::vector& exprs) { for (auto expr : exprs) { - expr->accept(this); + OptOutConstDispatch::handle(expr); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 5eb27c78f28..2f3af0061e1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -1,10 +1,10 @@ #pragma once -#include +#include #include #include -#include +#include #include #include @@ -14,10 +14,11 @@ namespace jit { namespace fuser { namespace cuda { -class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor { +// TODO: Replace with mutator as IndexLowering is replacing expr's with +// versions that are doing indexing +class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { public: - static std::vector getIndexedExprs( - std::vector incoming_exprs) { + static std::vector getIndexedExprs(std::vector incoming_exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::IndexLowering::getIndexedExprs"); IndexLowering il; il.generate(incoming_exprs); @@ -25,28 +26,29 @@ class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor { } private: - IndexLowering(); + IndexLowering() = default; - void pushBack(kir::Expr*); + void pushBack(Expr*); - void visit(const kir::ForLoop*) final; - void visit(const kir::IfThenElse*) final; - void visit(const kir::UnaryOp*) final; - void visit(const kir::BinaryOp*) final; - void visit(const kir::TernaryOp*) final; - void visit(const kir::ReductionOp*) final; - void visit(const kir::WelfordOp*) final; - void visit(const kir::BroadcastOp*) final; - void visit(const kir::Allocate*) final; - void visit(const kir::Sync*) final; + void handle(const UnaryOp*) final; + void handle(const BinaryOp*) final; + void handle(const TernaryOp*) final; + void handle(const ReductionOp*) final; + void handle(const WelfordOp*) final; + void handle(const BroadcastOp*) final; - void generate(const std::vector& exprs); + void handle(const kir::ForLoop*) final; + void handle(const kir::IfThenElse*) final; + void handle(const kir::Allocate*) final; + void handle(const kir::Sync*) final; - kir::Val* lowerSrcIndex(kir::Val* val, kir::Val* dst) const; - kir::Val* lowerDstIndex(kir::Val* dst) const; + void generate(const std::vector& exprs); + + Val* lowerSrcIndex(Val* val, Val* dst) const; + Val* lowerDstIndex(Val* dst) const; private: - std::vector lowered_exprs_; + std::vector lowered_exprs_; // This is a slight work around as scope has a couple definitions, we have the // Scope that's in ForLoop/IfThenElse which is really just a wrapper around @@ -55,9 +57,10 @@ class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor { // could be either the body or else body of the IfThenElse. However, we want // to understand the nesting of IfThenElse/ForLoop nodes. kir::Scope* active_scope_ = nullptr; - kir::Expr* active_scope_expr_ = nullptr; - kir::IrBuilder ir_builder_; + // Track for loops to send to indexing. Similar to what's done in + // kir::IrVisitor + std::vector for_loops_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 0947ef0f579..77be88183ec 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -1,8 +1,8 @@ #include #include +#include #include -#include -#include +#include #include #include @@ -33,8 +33,8 @@ class SmemAllocMap { public: //! Insert a new node if it's a SMEM allocation void insert(kir::Allocate* alloc) { - if (auto tv = dynamic_cast(alloc->buffer())) { - if (tv->memoryType() == MemoryType::Shared) { + if (auto tv = dynamic_cast(alloc->buffer())) { + if (tv->getMemoryType() == MemoryType::Shared) { // Note that a TensorView can have two allocations due to // unswitch. auto p = map_.insert({tv, alloc}); @@ -50,290 +50,298 @@ class SmemAllocMap { } } - //! Get the buffer that is actually allocated for a given TV - kir::TensorView* getRealBuffer(kir::TensorView* tv) const { + //! Run through aliases to get the buffer that is actually allocated for a + //! given TV + TensorView* getRealBuffer(TensorView* tv) const { auto it = map_.find(tv); TORCH_INTERNAL_ASSERT( - it != map_.end(), "Allocation not found for ", kir::toString(tv)); + it != map_.end(), "Allocation not found for ", tv->toString()); const kir::Allocate* alloc = it->second; while (alloc->alias()) { alloc = alloc->alias(); } auto buf = alloc->buffer(); - TORCH_INTERNAL_ASSERT(buf->isA()); - return buf->as(); + TORCH_INTERNAL_ASSERT(buf->isA()); + return buf->as(); } private: - std::unordered_map map_; + std::unordered_map map_; }; -//! Insert WAR sync for a given ForLoop -class LocalSyncInserterForLoop { - using TvSet = std::unordered_set; +struct WarMemoryInfo { + // True if there's a sync after the last read within the alloc loop. + bool sync_after_read = false; - public: - //! Insert Sync nodes at the end of a given for-loop when a WAR - //! hazard may happen. - LocalSyncInserterForLoop(kir::ForLoop* fl, SmemAllocMap& alloc_map) - : alloc_map_(alloc_map) { - for (auto expr : fl->body().exprs()) { - handle(expr); - } + // True if there's a sync before the first write. There can be multiple writes + // from memory aliasing. + bool sync_before_write = false; - // No need to insert sync when the loop is not actually generated - if (fl->iter_domain()->isThread() || fl->iter_domain()->isBroadcast()) { - return; - } - - // Determine if any smem TV is written to at beginning of the for-loop - // and whether that smem TV is read from at the end of the for-loop - // Insert new SyncThreads at end of for-loop to prevent WAR race condition - // - // TODO: replace __syncthreads with __threadfence for alias ops - // - if (detectIntersection(initial_, final_) && - !fl->body().exprs().back()->isA() && !is_last_op_sync_) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - fl->body().push_back(ir_builder.create(true)); - initial_sync_ = true; - is_last_op_sync_ = true; - final_.clear(); - } - } + // Has there been a read of this memory location + bool read_hit = false; - const auto& initial() const { - return initial_; - } + // Has there been *the* write to this memory location, assumes single write + // instruction (needs to be before conditionals added to code) + bool write_hit = false; - const auto& final() const { - return final_; - } + // For loop this TV is compute_at'ed in. + kir::ForLoop* ca_loop = nullptr; +}; - const auto& all_smem_inputs() const { - return all_smem_inputs_; +// To prevent shared memory from being over written before it is read, a +// synchronization point has to be inserted either between the allocation of an +// SMEM buffer and where we write into it, or after the buffer's last read +// before exiting the allocation's scope. +// +// e.g. +// for i: +// "alloc A" in shared memory - This is really marked by the compute_at point +// sync_loc_0 +// for j: +// sync_loc_1 +// for k: +// sync_loc_2 +// A = ... +// for k: +// ... = ... A +// for j: +// for k: +// ... = ... A +// sync_loc_3 +// sync_loc_4 +// sync_loc_5 +// +// All sync locations here provide valid protection that memory in A is finished +// being read before it is over written in the next iteration +// +// Insertion of sync threads will be done from the inner most position to the +// outer most. If a sync protecting the buffer is not already placed, the +// location prefered for the sync threads is the last possible position. One +// future optimization could be to not sync on the last iteration of the loop +// the sync is placed in. +class WarSyncInserter : private kir::ExprMutator { + public: + static std::vector insert(const std::vector& exprs) { + WarSyncInserter inserter(exprs); + return inserter.exprs_; } - const auto& all_smem_outputs() const { - return all_smem_outputs_; + private: + //! Insert Sync nodes at the end of a given for-loop when a WAR + //! hazard may happen. + WarSyncInserter(const std::vector& exprs) { + auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap(); + for (const auto& entry : lower_alloc_info_map) { + alloc_map_.insert(entry.first); + } + kir::ExprMutator::traverseAndInsert(exprs); } - void handle(kir::Expr* expr) { - if (ir_utils::isTVOp(expr)) { - is_last_op_sync_ = false; - - // For this SyncInserter - if (initial_sync_) { - addInputSmemTvs(expr, final_); - } else { - addInputSmemTvs(expr, final_); - addOutputSmemTvs(expr, initial_); + void handle(kir::IfThenElse* ite) final { + TORCH_INTERNAL_ASSERT( + ite->elseBody().empty(), + "Pass does not support conditional flow,", + " needs to be done before conditional execution is lowered."); + kir::ExprMutator::handle(ite); + } + + void handle(kir::Sync* sync) final { + // Register the sync for the active for loop + sync_hit_.back() = true; + // Run through the active allocations, if a read was hit, register there was + // a sync after the read. If there's subsequent reads on this buffer the + // sync_after_read will be cleared. + for (auto& entry : smem_allocations_) { + auto& alloc_stack = entry.second; + if (alloc_stack.back().read_hit) { + alloc_stack.back().sync_after_read = true; } - - // For parent SyncInserter - addOutputSmemTvs(expr, all_smem_outputs_); - addInputSmemTvs(expr, all_smem_inputs_); - } else if (auto sync = dynamic_cast(expr)) { - handle(sync); - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - } else if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } else if (auto alloc = dynamic_cast(expr)) { - alloc_map_.insert(alloc); } } - void handle(kir::Sync* sync) { - is_last_op_sync_ = true; - initial_sync_ = true; - final_.clear(); - } - - void handle(kir::IfThenElse* ite) { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); + // Checks if fl or loops within it have hit a sync + bool syncWithin(kir::ForLoop* fl) { + // If outer most scope check the first sync_hit_ position + if (fl == nullptr) { + return sync_hit_[0]; } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); - } - } - void handle(kir::ForLoop* fl) { - LocalSyncInserterForLoop child_sync_inserter(fl, alloc_map_); - - const auto& child_inputs = child_sync_inserter.all_smem_inputs(); - const auto& child_outputs = child_sync_inserter.all_smem_outputs(); - const bool maybe_skipped = !fl->start()->isZeroInt() && - !isParallelTypeThread(fl->iter_domain()->parallelType()); - - // Default - Track all smem inputs / outputs - all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end()); - all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end()); - - // Propagate the last_op_sync flag from the child loop. If the - // child is deterministically executed at least once, just set the - // flag with the child flag. Otherwise, conservatively set the - // flag, i.e., if the current flag is true and the child flag is - // also true, we can say the last op is still sync. - if (!maybe_skipped) { - is_last_op_sync_ = child_sync_inserter.is_last_op_sync_; - } else { - is_last_op_sync_ = - is_last_op_sync_ && child_sync_inserter.is_last_op_sync_; - } + // Find the for loop we want to look within + auto fl_it = std::find(for_loops_.begin(), for_loops_.end(), fl); - // When the child is not guaranteed to have sync. - if (!child_sync_inserter.initial_sync_) { - // If no sync is yet found, add the child outputs to - // initial. - if (!initial_sync_) { - initial_.insert(child_outputs.begin(), child_outputs.end()); - } - // Add the child inputs to final even when inital_sync is false, - // which only means sync may not be found yet. - final_.insert(child_inputs.begin(), child_inputs.end()); - } else { - // Similar to the above case, but here, the child is guaranteed - // to have sync, so we only need to look at initial and final. - if (!initial_sync_) { - initial_.insert( - child_sync_inserter.initial().begin(), - child_sync_inserter.initial().end()); - } - if (!maybe_skipped) { - initial_sync_ = true; - final_.clear(); - } - final_.insert( - child_sync_inserter.final().begin(), - child_sync_inserter.final().end()); - } - } + // Convert it to an index, but add one for the outer most scope + auto fl_i = std::distance(for_loops_.begin(), fl_it) + 1; - static bool detectIntersection(const TvSet& left, const TvSet& right) { - for (auto item : left) { - if (right.find(item) != right.end()) { + // Start at that index and see if there's syncs within that for loop + for (auto i : c10::irange(fl_i, sync_hit_.size())) { + if (sync_hit_[i]) { return true; } } return false; } - void addOutputSmemTvs(const kir::Expr* expr, TvSet& set) { - for (auto out : expr->outputs()) { - if (auto tv = dynamic_cast(out)) { - if (tv->memoryType() == MemoryType::Shared) { - auto real_tv = alloc_map_.getRealBuffer(tv); - set.insert(real_tv); - } + void handle(Expr* expr) final { + // If not a tensor view expression continue with dispatch + if (!ir_utils::isTvOp(expr)) { + kir::ExprMutator::handle(expr); + return; + } + + // Mark write has been hit for all output tvs + auto out_tvs = ir_utils::filterByType(expr->outputs()); + for (auto out_tv : out_tvs) { + if (out_tv->getMemoryType() != MemoryType::Shared) { + continue; } + auto& entry = getMemInfo(out_tv); + + // If this is the first write and there's a sync in one of the loops after + // the compute at loop, then this buffer is protected. + if (syncWithin(entry.ca_loop) && !entry.write_hit) { + entry.sync_before_write = true; + } + entry.write_hit = true; } - } - void addInputSmemTvs(const kir::Expr* expr, TvSet& set) { - for (auto in : expr->inputs()) { - if (auto tv = dynamic_cast(in)) { - if (tv->memoryType() == MemoryType::Shared) { - auto real_tv = alloc_map_.getRealBuffer(tv); - set.insert(real_tv); - } + // Mark read was hit, if sync_after_read was set, clear it. + auto inp_tvs = ir_utils::filterByType(expr->inputs()); + for (auto inp_tv : inp_tvs) { + if (inp_tv->getMemoryType() != MemoryType::Shared) { + continue; } + auto& entry = getMemInfo(inp_tv); + entry.read_hit = true; + // Clear the sync_after_read if it was set because there was another write + entry.sync_after_read = false; } } - private: - //! Allocation map of SMEM buffers - SmemAllocMap& alloc_map_; - - //! Track Shared Memory Inputs (Reads) for parent for-loop - TvSet all_smem_inputs_; - - //! Track Shared Memory Outputs (Writes) for parent for-loop - TvSet all_smem_outputs_; - - //! Shared Memory Writes at beginning of the for-loop - //! before first SyncThreads - TvSet initial_; + void handle(kir::ForLoop* for_loop) final { + // Push loop scope information + auto prev_within_iter_loop_ = within_iter_loop_; + sync_hit_.push_back(false); - //! Shared Memory Reads at end of the for-loop - //! Cleared after each SyncThreads - TvSet final_; + // If there is no real iterating loop WAR syncs aren't necessary + within_iter_loop_ = within_iter_loop_ || + !(for_loop->iter_domain()->isThread() || + for_loop->iter_domain()->isBroadcast() || + for_loop->iter_domain()->extent()->isOneInt()); - //! Track first sync deterministically found in for-loop. Even when a - //! child loop has a sync, if it may not be executed due to non-zero - //! start value, this flag remains false. - bool initial_sync_ = false; + // Process the expressions in the for loop + kir::ExprMutator::handle(for_loop); - //! Track if last op is sync - bool is_last_op_sync_ = false; -}; - -class LocalSyncInserter { - public: - //! Write-After-Read race conditions are only found within for-loops. - //! Sync nodes are inserted directly into the for-loops. - //! The expressions are modified in-place and exprs is const. - static void insertSyncs(const std::vector& exprs) { - LocalSyncInserter inserter; - inserter.insert(exprs); - } + // Sync analysis and cleanup: + // + // Pop for loop stack inside WarMemoryInfo structs if they match this one. + // Erase empty entries so we don't continue to search over them + // + // Insert sync at end of this for loop if any of the entries require + std::vector to_erase; + bool insert_sync = false; + for (auto& entry : smem_allocations_) { + auto& alloc_stack = entry.second; + if (alloc_stack.size() && alloc_stack.back().ca_loop == for_loop) { + if (!alloc_stack.back().sync_after_read && + !alloc_stack.back().sync_before_write) { + insert_sync = within_iter_loop_; + } - private: - void insert(const std::vector& exprs) { - for (auto expr : exprs) { - if (auto fl = dynamic_cast(expr)) { - LocalSyncInserterForLoop sync_inserter(fl, alloc_map_); - } else if (auto ite = dynamic_cast(expr)) { - insert(ite->thenBody().exprs()); - insert(ite->elseBody().exprs()); - } else if (auto alloc = dynamic_cast(expr)) { - alloc_map_.insert(alloc); + alloc_stack.pop_back(); + if (alloc_stack.empty()) { + to_erase.push_back(entry.first); + } } } - } - private: + for (auto tv : to_erase) { + smem_allocations_.erase(tv); + } + + // WAR Sync is necessary in this loop, register its insertion. + if (insert_sync) { + auto sync_expr = IrBuilder::create(true); + kir::ExprMutator::registerInsertAfter( + for_loop->body().exprs().back(), sync_expr, &for_loop->body()); + handle(sync_expr); + } + + // Pop for loop scope information + sync_hit_.pop_back(); + within_iter_loop_ = prev_within_iter_loop_; + } + + // Create a new WarMemoryInfo entry if required and return a reference to it, + // else return the WarMemoryInfo associated with tv + WarMemoryInfo& getMemInfo(TensorView* tv) { + auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv); + auto alloc_it = smem_allocations_.find(maybe_aliased_tv); + auto ca_loop = + loop_utils::getAllocInformation(tv, for_loops_).init_for_loop; + if (alloc_it == smem_allocations_.end()) { + WarMemoryInfo mem_info; + mem_info.ca_loop = ca_loop; + auto entry_it = + smem_allocations_ + .insert(std::make_pair( + maybe_aliased_tv, std::vector({mem_info}))) + .first; + return entry_it->second.back(); + } else if ( + maybe_aliased_tv != tv && alloc_it->second.back().ca_loop != ca_loop) { + WarMemoryInfo mem_info; + mem_info.ca_loop = ca_loop; + auto& alloc_stack = alloc_it->second; + alloc_stack.push_back(mem_info); + return alloc_stack.back(); + } + return alloc_it->second.back(); + } + + //! Allocation map of SMEM buffers. Needed because of SMEM buffer aliasing, + //! need to track the root of the alias to properly insert WAR hazard syncs SmemAllocMap alloc_map_; + + //! Is there a loop nest that has a non-trivial iteration (extent != 1) and + //! not bound to a block/thread. This indicates if a WAR sync is necessary, + //! otherwise the Expr is not in an iterating for loop. + bool within_iter_loop_ = false; + + // Track which loops have hit a sync. Used to see if there's a sync before + // write. + std::vector sync_hit_ = {false}; + + // Keep track of the active allocations we need to protect. Key is the + // "getRealBuffer", not the raw tv. There can be multiple WarMemoryInfo's + // because of aliasing. If the "getRealBuffer" tv has a compute at outside the + // alias tv, each aliased tv in a unique ca_loop has to be tracked separately + // for WAR insertion. + std::unordered_map> smem_allocations_; }; class ExprFlattener : private kir::IrVisitor { private: - void handle(kir::Expr* expr) { + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { if (expr->isA() || expr->isA()) { - expr->accept(this); + kir::IrVisitor::handle(expr); } else { - exprs_.push_back(expr); - } - } - - void visit(const kir::ForLoop* fl) final { - for (auto expr : fl->body().exprs()) { - handle(expr); - } - } - - void visit(const kir::IfThenElse* ite) final { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); + flat_exprs_.push_back(expr); } } private: - std::vector exprs_; + std::vector flat_exprs_; public: //! Flattens scopes extracting out a single ordered list of exprs. - static std::vector flatten( - const std::vector& loop_nests) { + static std::vector flatten(const std::vector& loop_nests) { ExprFlattener flattener; for (auto expr : loop_nests) { flattener.handle(expr); } - return flattener.exprs_; + return flattener.flat_exprs_; } }; @@ -342,53 +350,42 @@ class ValidatePlacementAfterWrites : private kir::IrVisitor { //! Validate no expr in writes found under loop static void validate( kir::ForLoop* loop, - const std::unordered_set& writes) { + const std::unordered_set& writes) { ValidatePlacementAfterWrites validator(writes); validator.handle(loop); } private: - ValidatePlacementAfterWrites(const std::unordered_set& writes) + using kir::IrVisitor::handle; + + ValidatePlacementAfterWrites(const std::unordered_set& writes) : writes_(writes) {} - void handle(kir::Expr* expr) { + void handle(Expr* expr) final { if (expr->isA() || expr->isA()) { - expr->accept(this); + kir::IrVisitor::handle(expr); } else { TORCH_INTERNAL_ASSERT( writes_.find(expr) == writes_.end(), "Block sync must be placed after ", - kir::toString(expr)); - } - } - - void visit(const kir::ForLoop* fl) final { - for (auto expr : fl->body().exprs()) { - handle(expr); - } - } - - void visit(const kir::IfThenElse* ite) final { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); + expr->toString()); } } private: - const std::unordered_set& writes_; + const std::unordered_set& writes_; }; -class ReadAfterWriteSyncs : public kir::MutableIrVisitor { +class ReadAfterWriteSyncs : public kir::ExprMutator { private: + using kir::ExprMutator::handle; + //! Traverse up the loop stack from loops_it and if a halo loop is //! found, place a given sync expr before the outer-most halo loop. bool insertBeforeHaloLoop( std::vector::iterator loops_it, kir::Sync* sync_expr, - const std::unordered_set& writes) { + const std::unordered_set& writes) { std::vector::iterator halo_loop_it; bool halo_loop_found = false; @@ -420,21 +417,21 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { if (halo_loop_it == for_loops_.begin()) { // place in global scope - auto place_before_it = - std::find(loop_nests_.begin(), loop_nests_.end(), halo_loop); - TORCH_INTERNAL_ASSERT(place_before_it != loop_nests_.end()); - loop_nests_.insert(place_before_it, sync_expr); + auto place_before_it = std::find(exprs_.begin(), exprs_.end(), halo_loop); + TORCH_INTERNAL_ASSERT(place_before_it != exprs_.end()); + exprs_.insert(place_before_it, sync_expr); } else { auto place_in = *(halo_loop_it - 1); - place_in->body().insert_before(halo_loop, sync_expr); + kir::ExprMutator::registerInsertBefore( + halo_loop, sync_expr, &place_in->body()); } return true; } - void handle(kir::Expr* expr) { - if (!ir_utils::isTVOp(expr) || expr->isA()) { - expr->accept(this); + void handle(Expr* expr) final { + if (!ir_utils::isTvOp(expr) || expr->isA()) { + kir::ExprMutator::handle(expr); return; } @@ -443,8 +440,8 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { auto last_writes = last_writes_.front(); last_writes_.pop_front(); // Found that a sync is needed - TORCH_INTERNAL_ASSERT(expr->outputs()[0]->isA()); - auto out_tv = expr->outputs()[0]->as(); + TORCH_INTERNAL_ASSERT(expr->outputs()[0]->isA()); + auto out_tv = expr->outputs()[0]->as(); // Find where a sync needs to be inserted // This is very similar to how allocations are placed, simply place sync @@ -454,39 +451,35 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { // out of or saving state for tensor view ID -> for loop // TODO: Explicitly test the 3 cases below - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto sync_expr = ir_builder.create(); - if (out_tv->fuserTv()->getComputeAtPosition() == 0) { + auto sync_expr = IrBuilder::create(); + if (out_tv->getComputeAtPosition() == 0) { // Sync should be placed at global scope, after its outer most loop if // it has one. - kir::Expr* place_after = for_loops_.size() > 0 ? for_loops_[0] : expr; - // Find location in loop_nests_ + Expr* place_after = for_loops_.size() > 0 ? for_loops_[0] : expr; + // Find location in exprs_ auto place_after_it = - std::find(loop_nests_.begin(), loop_nests_.end(), place_after); + std::find(exprs_.begin(), exprs_.end(), place_after); TORCH_INTERNAL_ASSERT( - place_after_it != loop_nests_.end(), + place_after_it != exprs_.end(), "Could not figure out where to place synchronization. ", "Tried to place after, ", - toString(place_after), + place_after->toString(), ", but could not find this expression at the global scope."); - loop_nests_.insert(place_after_it + 1, sync_expr); + + registerInsertAfter(*(place_after_it + 1), sync_expr, nullptr); } else { // Find the last loop in computeAt of out_tv, this is the loop where we // would place an allocation for out_tv - auto fuser_tv = out_tv->fuserTv(); - auto lowered_local_id = - GpuLower::current() - ->lowerValue(fuser_tv->axis( - (int)out_tv->fuserTv()->getComputeAtPosition() - 1)) - ->as(); + auto local_id = out_tv->axis((int)out_tv->getComputeAtPosition() - 1); auto loops_it = std::find_if( for_loops_.begin(), for_loops_.end(), - [&lowered_local_id](const auto& loop) { + [&local_id](const auto& loop) { return GpuLower::current()->caLoopMap().areMapped( - loop->iter_domain(), lowered_local_id) || - loop->iter_domain()->parallelType() == ParallelType::Unroll; + loop->iter_domain(), local_id) || + loop->iter_domain()->getParallelType() == + ParallelType::Unroll; }); TORCH_INTERNAL_ASSERT(loops_it != for_loops_.end()); @@ -497,7 +490,7 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { } auto place_in = *loops_it; - kir::Expr* place_after = nullptr; + Expr* place_after = nullptr; if (loops_it + 1 == for_loops_.end()) { // Inline allocation, place after expr @@ -509,22 +502,12 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { place_after = *(loops_it + 1); } - place_in->body().insert_after(place_after, sync_expr); + registerInsertAfter(place_after, sync_expr, &place_in->body()); } } } - void visit(kir::ForLoop* fl) final { - for_loops_.push_back(fl); - // Modifying in place, make a copy of the vector - const std::vector exprs = fl->body().exprs(); - for (auto expr : exprs) { - handle(expr); - } - for_loops_.pop_back(); - } - - void visit(kir::IfThenElse*) final { + void handle(kir::IfThenElse*) final { TORCH_INTERNAL_ASSERT( false, "Pass does not support conditional statements, ", @@ -532,18 +515,17 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { } // Clear the modify status for all shared memory buffers - static void cleanSharedMemory( - std::unordered_map& smem) { + static void cleanSharedMemory(std::unordered_map& smem) { smem.clear(); } // Return a set of expressions that modify shared-memory // tensors. Expressions are excluded when syncthreads are already // placed. - std::unordered_set isModifiedSharedMemory( - const std::unordered_map& smem, - const std::vector& tvs) const { - std::unordered_set last_writes; + std::unordered_set isModifiedSharedMemory( + const std::unordered_map& smem, + const std::vector& tvs) const { + std::unordered_set last_writes; for (auto tv : tvs) { auto it = smem.find(tv); if (it != smem.end()) { @@ -553,18 +535,17 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { return last_writes; } - ReadAfterWriteSyncs(std::vector _loop_nests) - : loop_nests_(std::move(_loop_nests)) { + ReadAfterWriteSyncs(const std::vector& _exprs) { // Fusion shared_memory values // Tracks if shared memory is modified - std::unordered_map smem; + std::unordered_map smem; // Flatten all the expressions - auto flattened_exprs = ExprFlattener::flatten(loop_nests_); + auto flattened_exprs = ExprFlattener::flatten(_exprs); - kir::Expr* prev_tv_expr = nullptr; + Expr* prev_tv_expr = nullptr; for (auto expr : flattened_exprs) { - if (!ir_utils::isTVOp(expr) || expr->isA()) { + if (!ir_utils::isTvOp(expr) || expr->isA()) { continue; } @@ -578,22 +559,20 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { cleanSharedMemory(smem); } - for (auto out : expr->outputs()) { - if (out->isA()) { - if (out->as()->memoryType() == MemoryType::Shared) { - smem[out] = expr; - } + for (auto tv : ir_utils::filterByType(expr->outputs())) { + // Double buffered tensors do not need RAW sync to be inserted + // here, except for the initial load part, which is taken care + // separately by DoubleBufferInserter. + if (tv->getMemoryType() == MemoryType::Shared && + !tv->isDoubleBuffered()) { + smem[tv] = expr; } } prev_tv_expr = expr; } - // Insert read after write syncs - const std::vector exprs = loop_nests_; - for (auto expr : exprs) { - handle(expr); - } + kir::ExprMutator::traverseAndInsert(_exprs); TORCH_INTERNAL_ASSERT( sync_after_.empty(), "Didn't place all required syncs."); @@ -601,7 +580,7 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { private: //! Keep track of expressions that must be followed by syncthreads - std::deque sync_after_; + std::deque sync_after_; //! Keep track of write expressions that must be placed before //! syncthreads. @@ -611,35 +590,27 @@ class ReadAfterWriteSyncs : public kir::MutableIrVisitor { //! be placed before that. last_writes_ keeps track of expressions //! modifying the smem buffer each syncthreads is used for so that //! it is not placed before those write expressions. - std::deque> last_writes_; - - //! Keep track of for loops while inserting syncthreads - std::vector for_loops_; - - //! Loop-nests where syncthreads are inserted - std::vector loop_nests_; + std::deque> last_writes_; public: - static std::vector insert( - const std::vector& loop_nests) { + static std::vector insert(const std::vector& loop_nests) { ReadAfterWriteSyncs inserter(loop_nests); - return inserter.loop_nests_; + return inserter.exprs_; } }; } // namespace -std::vector insertRawThreadSynchronization( - const std::vector& exprs) { +std::vector insertRawThreadSynchronization( + const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertRawThreadSynchronization"); return ReadAfterWriteSyncs::insert(exprs); } -std::vector insertWarThreadSynchronization( - const std::vector& exprs) { +std::vector insertWarThreadSynchronization( + const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertWarThreadSynchronization"); - LocalSyncInserter::insertSyncs(exprs); - return exprs; + return WarSyncInserter::insert(exprs); } } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h index 50618373448..756462f0bd7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -16,40 +16,14 @@ namespace cuda { //! //! WAR race condition occurs when the next iteration of the loop overwrites //! shared memory value before a previous operation has finished reading it. -//! -//! WAR Race Check: -//! Track all output shared memory TVs before first sync -//! Track all input shared memory TVs after last sync -//! If the intersection is non-empty, then there is a WAR race condition. -//! Recursively check each nested for-loop -//! -//! Parent-Child For-Loop Recursive Relationship -//! Notation: -//! None - Zero Syncs -//! 1+ - One or more Syncs -//! End - Sync is last op in for-loop to prevent WAR race condition -//! -//! Default: Track all shared memory inputs and outputs -//! -//! Parent - None -//! Child - None => Append All Child Outputs to Parent Initial -//! Child - 1+ => Parent first sync => Inherit Child Initial + Final -//! Child - End => Parent first sync => Keep Child Initial / Clear Parent Final -//! -//! Parent - 1+ -//! Child - None => Append All Child to Parent Last -//! Child - 1+ => Child Final to Parent Final / Discard Child Initial -//! Child - End => Clear Parent Last / Discard Child Initial -//! -//! If Child - End and Parent has zero remaining operations, then -//! Parent inherits Child End. -//! -std::vector insertWarThreadSynchronization( - const std::vector& exprs); +std::vector insertWarThreadSynchronization( + const std::vector& exprs); //! Insert syncs between writing to shared memory and then reading it. -std::vector insertRawThreadSynchronization( - const std::vector& exprs); +//! RAW pass is run before indexing, unrolling (loop duplication), memory +//! aliasing, and index (grid/block bcast/reduction) +std::vector insertRawThreadSynchronization( + const std::vector& exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index e4396f9a864..12c7d33e077 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -19,7 +18,7 @@ namespace jit { namespace fuser { namespace cuda { -std::vector LoopNestGenerator::loweredExprs( +std::vector LoopNestGenerator::loweredExprs( const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::LoopNestGenerator::loweredExprs"); TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr); @@ -33,22 +32,20 @@ LoopNestGenerator::LoopNestGenerator(const std::vector& exprs) { namespace { -kir::ForLoop* openForHelper(kir::ForLoop* scope, kir::IterDomain* kir_id) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto extent_with_halo = gpu_lower->haloInfo().getExtent(kir_id); +kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) { + auto extent_with_halo = GpuLower::current()->haloInfo().getExtent(id); kir::ForLoop* new_scope = nullptr; if (extent_with_halo) { // When an axis is extended with halo, unrolling and vectorization // are assumed to not be used for now. TORCH_INTERNAL_ASSERT( - kir_id->parallelType() != ParallelType::Unroll && - !isParallelTypeVectorize(kir_id->parallelType())); + id->getParallelType() != ParallelType::Unroll && + !isParallelTypeVectorize(id->getParallelType())); // Use the extent that's extended by halo - new_scope = ir_builder.create( - kir_id, - kir_id->isBroadcast() ? ir_builder.zeroVal() - : ir_builder.create(c10::nullopt), + new_scope = IrBuilder::create( + id, + id->isBroadcast() ? GpuLower::current()->kernel()->zeroVal() + : IrBuilder::create(c10::nullopt), nullptr, extent_with_halo, nullptr, @@ -56,7 +53,7 @@ kir::ForLoop* openForHelper(kir::ForLoop* scope, kir::IterDomain* kir_id) { nullptr, false); } else { - new_scope = ir_builder.create(kir_id); + new_scope = IrBuilder::create(id); } if (scope != nullptr) { scope->body().insert(0, new_scope); @@ -66,13 +63,13 @@ kir::ForLoop* openForHelper(kir::ForLoop* scope, kir::IterDomain* kir_id) { } // namespace -void LoopNestGenerator::openFor(kir::IterDomain* kir_iter_domain) { +void LoopNestGenerator::openFor(IterDomain* id) { if (for_loops_.size() > 0) { - const auto new_scope = openForHelper(for_loops_.back(), kir_iter_domain); + const auto new_scope = openForHelper(for_loops_.back(), id); // for_loop_allocations_.insert({new_scope, 0}); for_loops_.push_back(new_scope); } else { - for_loops_.push_back(openForHelper(nullptr, kir_iter_domain)); + for_loops_.push_back(openForHelper(nullptr, id)); lowered_exprs_.insert(lowered_exprs_.begin(), for_loops_.back()); } } @@ -82,7 +79,7 @@ void LoopNestGenerator::closeFor() { for_loops_.pop_back(); } -void LoopNestGenerator::pushFront(kir::Expr* expr) { +void LoopNestGenerator::pushFront(Expr* expr) { if (for_loops_.size() == 0) { lowered_exprs_.insert(lowered_exprs_.begin(), expr); } else { @@ -91,18 +88,15 @@ void LoopNestGenerator::pushFront(kir::Expr* expr) { } void LoopNestGenerator::handle(Expr* expr) { - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - // Check if it's a tensor view expression we need to place in the loop nest // structure - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { // Close all the loops, scalar operations cannot be inside for loops based // on expr sorting. while (!for_loops_.empty()) { closeFor(); } - pushFront(gpu_lower->lowerExpr(expr)); + pushFront(expr); for (auto out : expr->outputs()) { TORCH_INTERNAL_ASSERT( @@ -112,10 +106,8 @@ void LoopNestGenerator::handle(Expr* expr) { " cannot lower ", out->getValType().value()); - pushFront(ir_builder.create( - gpu_lower->lowerValue(out), - MemoryType::Local, - ir_builder.create(1))); + pushFront(IrBuilder::create( + out, MemoryType::Local, GpuLower::current()->kernel()->oneVal())); } return; } @@ -130,27 +122,19 @@ void LoopNestGenerator::handle(Expr* expr) { // Figure out what the entire loop structure should look like. std::vector loop_structure = loop_structures_.at(out_tv); - std::vector kir_loop_structure; - std::transform( - loop_structure.begin(), - loop_structure.end(), - std::back_inserter(kir_loop_structure), - [&gpu_lower](IterDomain* id) { - return gpu_lower->lowerValue(id)->as(); - }); // Ordering of loop_structure is global, so simply close loops we don't need, // and open the ones we do. while (!for_loops_.empty() && std::find( - kir_loop_structure.begin(), - kir_loop_structure.end(), - for_loops_.back()->iter_domain()) == kir_loop_structure.end()) { + loop_structure.begin(), + loop_structure.end(), + for_loops_.back()->iter_domain()) == loop_structure.end()) { closeFor(); } - for (auto loop : kir_loop_structure) { + for (auto loop : loop_structure) { auto find_it = std::find_if( for_loops_.begin(), for_loops_.end(), [loop](kir::ForLoop* fl) { return fl->iter_domain() == loop; @@ -160,7 +144,7 @@ void LoopNestGenerator::handle(Expr* expr) { } } - pushFront(gpu_lower->lowerExpr(expr)); + pushFront(expr); } namespace { diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index fbbdf079e89..9b480d7eb6f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -1,13 +1,12 @@ #pragma once -#include +#include #include #include #include #include -#include #include namespace torch { @@ -30,20 +29,20 @@ namespace cuda { //! nests to initialize reduction buffers. class TORCH_CUDA_CU_API LoopNestGenerator { public: - static std::vector loweredExprs(const std::vector& exprs); + static std::vector loweredExprs(const std::vector& exprs); private: LoopNestGenerator(const std::vector& exprs); // Open a new inner most for loop, track which TV it was constructed from // according to the computeAt chain. - void openFor(kir::IterDomain*); + void openFor(IterDomain*); // Close the inner most for loop void closeFor(); // Appends an expression to the current scope - void pushFront(kir::Expr* expr); + void pushFront(Expr* expr); void handle(Expr* expr); @@ -52,7 +51,7 @@ class TORCH_CUDA_CU_API LoopNestGenerator { private: // Lowered exprs to return - std::vector lowered_exprs_; + std::vector lowered_exprs_; // Keep all for loops conveniently to make unrolling easier, basically just a // stack of the active for_loops diff --git a/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp b/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp index f5f5c72676a..f17f91806d6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include namespace torch { @@ -12,11 +12,11 @@ namespace cuda { namespace { -class MagicZeroInserter : public kir::MutableIrVisitor { +class MagicZeroInserter : public kir::ExprMutator { public: - static std::vector insert(const std::vector& exprs) { + static std::vector insert(const std::vector& exprs) { MagicZeroInserter inserter(exprs); - return inserter.loop_nests_; + return inserter.exprs_; } private: @@ -25,94 +25,43 @@ class MagicZeroInserter : public kir::MutableIrVisitor { kir::ForLoop* fl = nullptr; }; - MagicZeroInserter(const std::vector& exprs) - : loop_nests_(exprs), ir_builder(GpuLower::current()->kernel()) { - loop_nests_.insert( - loop_nests_.begin(), ir_builder.create()); - for (auto expr : exprs) { - handle(expr); - } - insertAll(); - } - - void handle(kir::Expr* expr) { - if (auto ite = dynamic_cast(expr)) { - handle(ite); - } else if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } - } - - void handle(kir::IfThenElse* ite) { - scope_nest_.push_back(&ite->thenBody()); - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - scope_nest_.pop_back(); - scope_nest_.push_back(&ite->elseBody()); - for (auto expr : ite->elseBody().exprs()) { - handle(expr); - } - scope_nest_.pop_back(); + MagicZeroInserter(const std::vector& exprs) { + TORCH_INTERNAL_ASSERT(exprs.size()); + kir::ExprMutator::registerInsertBefore( + exprs.front(), IrBuilder::create(), nullptr); + kir::ExprMutator::traverseAndInsert(exprs); } - void handle(kir::ForLoop* fl) { + void handle(kir::ForLoop* fl) final { if (fl->isUnrolled()) { - kir::Scope* scope = nullptr; - if (!scope_nest_.empty()) { - scope = scope_nest_.back(); - } - insertion_list_.push_back({scope, fl}); - } else { - scope_nest_.push_back(&fl->body()); - for (auto expr : fl->body().exprs()) { - handle(expr); - } - scope_nest_.pop_back(); - } - } - - void insertAll() { - for (const auto& info : insertion_list_) { - auto fl = info.fl; - auto scope = info.scope; - if (scope == nullptr) { - // place in global scope - auto loop_it = std::find(loop_nests_.begin(), loop_nests_.end(), fl); - TORCH_INTERNAL_ASSERT(loop_it != loop_nests_.end()); - // Place after the loop - loop_it++; - loop_nests_.insert(loop_it, ir_builder.create()); + if (scope_.empty()) { + kir::ExprMutator::registerInsertAfter( + fl, IrBuilder::create()); } else { - scope->insert_after(fl, ir_builder.create()); + TORCH_INTERNAL_ASSERT( + scope_.back()->exprs().size(), "Not expecting an empty loop."); + kir::ExprMutator::registerInsertAfter( + fl, IrBuilder::create(), scope_.back()); } + } else { + kir::ExprMutator::handle(fl); } } - //! Keep track for loop structure - std::vector scope_nest_; - - // Keep a copy of the expressions provided - std::vector loop_nests_; - - kir::IrBuilder ir_builder; - std::vector insertion_list_; }; } // namespace -std::vector insertMagicZero(const std::vector& exprs) { +std::vector insertMagicZero(const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::insertMagicZero"); // Check if magic zero was even used, if not we don't have to define it or // update it. const auto gpu_lower = GpuLower::current(); auto kernel = gpu_lower->kernel(); - const bool has_magic_zero = std::any_of( - kernel->irNodes().begin(), - kernel->irNodes().end(), - [](const std::unique_ptr& ir_node) { - return ir_node->isA() && isMagicZero(ir_node->as()); + const bool has_magic_zero = + std::any_of(kernel->vals().begin(), kernel->vals().end(), [](Val* val) { + return isMagicZero(val); }); if (!has_magic_zero) { @@ -122,19 +71,21 @@ std::vector insertMagicZero(const std::vector& exprs) { return MagicZeroInserter::insert(exprs); } -bool isMagicZero(kir::Val* val) { - auto ns = dynamic_cast(val); - if (ns == nullptr) { +bool isMagicZero(const Val* val) { + if (!val->isA()) { return false; } + auto ns = val->as(); return ns->dtype() == DataType::Int && ns->name() == std::string(kMagicZeroName); } -bool isProtectedWithMagicZero(kir::Val* val) { - auto def = dynamic_cast(val->definition()); - return def && def->operation() == BinaryOpType::Add && - isMagicZero(def->rhs()); +bool isProtectedWithMagicZero(const Val* val) { + if (val->definition() == nullptr || !val->definition()->isA()) { + return false; + } + auto bop = val->definition()->as(); + return bop->getBinaryOpType() == BinaryOpType::Add && isMagicZero(bop->rhs()); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_magic_zero.h b/torch/csrc/jit/codegen/cuda/lower_magic_zero.h index 03a37a46813..942a3302801 100644 --- a/torch/csrc/jit/codegen/cuda/lower_magic_zero.h +++ b/torch/csrc/jit/codegen/cuda/lower_magic_zero.h @@ -14,15 +14,15 @@ namespace cuda { //! zero update after every (outer most) loop nest with a compile time extent. //! //! This will make sure nvrtc does not aggressively save predicate and indices. -std::vector insertMagicZero(const std::vector& exprs); +std::vector insertMagicZero(const std::vector& exprs); //! Check if val is a reference to the magic zero variable -bool isMagicZero(kir::Val* val); +bool isMagicZero(const Val* val); //! Check if val is protected with magic zero. //! //! Specifically, this returns true if val is defined as "x + magic_zero". -bool isProtectedWithMagicZero(kir::Val* val); +bool isProtectedWithMagicZero(const Val* val); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp index b94c12c27c8..66b405ac8e2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp @@ -5,8 +5,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -18,85 +17,64 @@ namespace cuda { namespace { -class MisalignedVectorizationModifier { +class MisalignedVectorizationModifier : public kir::ExprMutator { public: - void process(const std::vector& exprs) { - FUSER_PERF_SCOPE( - "GpuLower::Lower::MisalignedVectorizationModifier::process"); - // Run through loop nests - // Find for-loops with misaligned vectorization domains - for (auto* expr : exprs) { - handle(expr); - } - } + MisalignedVectorizationModifier() = delete; - const std::unordered_map& replacementMap() const { - return expr_replacement_map_; + static std::vector processMisalignedVectorization( + const std::vector& exprs) { + FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization"); + MisalignedVectorizationModifier mvm(exprs); + return mvm.exprs_; } private: - void handle(kir::Expr* expr) { - if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - } + MisalignedVectorizationModifier(const std::vector& exprs) { + FUSER_PERF_SCOPE("GpuLower::Lower::MisalignedVectorizationModifier"); + // Run through loop nests + // Find for-loops with misaligned vectorization domains + kir::ExprMutator::traverseAndInsert(exprs); } - void handle(kir::ForLoop* fl) { - for_loops_structure_.push_back(fl); - - // Make copy of exprs because we replace them inplace in fl - const auto exprs_copy = fl->body().exprs(); - + void handle(kir::ForLoop* fl) final { + kir::Scope* scope = scope_.empty() ? nullptr : scope_.back(); if (containsAnyDirectChildMisalignedVectorize(fl)) { - auto new_fl = handleMisalignedVectorize(for_loops_structure_, fl); - expr_replacement_map_.insert({fl, new_fl}); - } else { - for (auto expr : exprs_copy) { - handle(expr); - } - } + for_loops_.push_back(fl); + auto new_fl = handleMisalignedVectorize(for_loops_, fl); + for_loops_.pop_back(); - for_loops_structure_.pop_back(); - } - - void handle(kir::IfThenElse* ite) { - for (auto expr : ite->thenBody().exprs()) { - handle(expr); - } - for (auto expr : ite->elseBody().exprs()) { - handle(expr); + kir::ExprMutator::registerReplace(fl, new_fl, scope); + } else { + kir::ExprMutator::handle(fl); } } struct ReferenceTensors { // Input TensorView to Vectorize Set operation - kir::TensorView* in_tv = nullptr; + TensorView* in_tv = nullptr; // Output TensorView to Vectorize Set operation - kir::TensorView* out_tv = nullptr; + TensorView* out_tv = nullptr; // TensorView in global memory - kir::TensorView* global_tv = nullptr; + TensorView* global_tv = nullptr; // TensorView with vectorize IterDomain and not in global memory - kir::TensorView* vec_tv = nullptr; + TensorView* vec_tv = nullptr; }; - ReferenceTensors getReferenceTensors(kir::Expr* vectorized_expr) { + ReferenceTensors getReferenceTensors(Expr* vectorized_expr) { TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr); TORCH_INTERNAL_ASSERT( - vectorized_expr->outputs().front()->isA()); - TORCH_INTERNAL_ASSERT( - vectorized_expr->inputs().front()->isA()); + vectorized_expr->outputs().front()->isA()); + TORCH_INTERNAL_ASSERT(vectorized_expr->inputs().front()->isA()); - auto in_tv = vectorized_expr->inputs().front()->as(); - auto out_tv = vectorized_expr->outputs().front()->as(); + auto in_tv = vectorized_expr->inputs().front()->as(); + auto out_tv = vectorized_expr->outputs().front()->as(); const bool global_vectorize_write_op = - (out_tv->memoryType() == MemoryType::Global && - in_tv->memoryType() == MemoryType::Local); + (out_tv->getMemoryType() == MemoryType::Global && + in_tv->getMemoryType() == MemoryType::Local); const bool global_vectorize_read_op = - (out_tv->memoryType() == MemoryType::Local && - in_tv->memoryType() == MemoryType::Global); + (out_tv->getMemoryType() == MemoryType::Local && + in_tv->getMemoryType() == MemoryType::Global); TORCH_INTERNAL_ASSERT( global_vectorize_write_op || global_vectorize_read_op, "Unsupported vectorize memory configuration detected."); @@ -104,25 +82,26 @@ class MisalignedVectorizationModifier { // TensorView on global memory. This is the tensor that may have // a non-aligned base address. auto global_tv = - (out_tv->memoryType() == MemoryType::Global) ? out_tv : in_tv; + (out_tv->getMemoryType() == MemoryType::Global) ? out_tv : in_tv; // TensorView with the misaligned vec iterDomain. It is the consumer // of vectorized load or the producer of vectorized store. It is // assumed that when the output TV is not on global memory, this // expression is a vectorized load, so the output TV is vec_tv. - auto vec_tv = (out_tv->memoryType() != MemoryType::Global) ? out_tv : in_tv; + auto vec_tv = + (out_tv->getMemoryType() != MemoryType::Global) ? out_tv : in_tv; return {in_tv, out_tv, global_tv, vec_tv}; } struct VectorizeData { - kir::Val* vector_size = nullptr; - kir::Val* shift = nullptr; - kir::Val* extent = nullptr; - kir::Val* remainder = nullptr; - kir::Val* extent_minus_remainder = nullptr; - kir::Val* last_root_domain_index = nullptr; - kir::Val* last_root_domain_index_shift = nullptr; + Val* vector_size = nullptr; + Val* shift = nullptr; + Val* extent = nullptr; + Val* remainder = nullptr; + Val* extent_minus_remainder = nullptr; + Val* last_root_domain_index = nullptr; + Val* last_root_domain_index_shift = nullptr; }; // Create constants for handling misaligned addresses @@ -130,48 +109,43 @@ class MisalignedVectorizationModifier { const std::vector& for_loop_structure, const ReferenceTensors& tensors, kir::IfThenElse* parent_scope_ite) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - // Generate vectorize index - auto indices = (tensors.out_tv->memoryType() == MemoryType::Global) - ? Index::getConsumerStridedIndices( - tensors.out_tv->fuserTv(), for_loop_structure) + auto indices = (tensors.out_tv->getMemoryType() == MemoryType::Global) + ? Index::getConsumerStridedIndices(tensors.out_tv, for_loop_structure) : Index::getProducerStridedIndices( - tensors.in_tv->fuserTv(), - tensors.out_tv->fuserTv(), - for_loop_structure); + tensors.in_tv, tensors.out_tv, for_loop_structure); // >>>>>>>>>>>>> // Number of elements in vectorize access auto vector_size = - tensors.vec_tv->domain()->domain().back()->extent()->as(); + tensors.vec_tv->domain()->domain().back()->extent()->as(); // Size of memory type for the elements - kir::Int* data_size_in_bytes = - ir_builder.create(dataTypeSize(tensors.vec_tv->dtype())); + Int* data_size_in_bytes = + IrBuilder::create(dataTypeSize(tensors.vec_tv->dtype())); // The number of bytes in the vectorize access auto vector_size_in_bytes = - ir_builder.mulExpr(vector_size, data_size_in_bytes); + IrBuilder::mulExpr(vector_size, data_size_in_bytes); - auto index = ir_builder.create( - tensors.global_tv->fuserTv(), indices); + auto index = + IrBuilder::create(tensors.global_tv, indices); auto address = createNamedScalarFromValue( parent_scope_ite->thenBody(), index, "address", true); // offset_size = (address % vector_size_bytes) / data_type_size_bytes // shift_init = vector_size - offset_size - auto a = ir_builder.modExpr(address, vector_size_in_bytes); - auto b = ir_builder.divExpr(a, data_size_in_bytes); - auto c = ir_builder.subExpr(vector_size, b); + auto a = IrBuilder::modExpr(address, vector_size_in_bytes); + auto b = IrBuilder::divExpr(a, data_size_in_bytes); + auto c = IrBuilder::subExpr(vector_size, b); auto shift_init = createNamedScalarFromValue( parent_scope_ite->thenBody(), c, "shift_val"); // shift = (shift_init == vector_size) ? 0 : shift_init // The number of elements until the first aligned address - auto shift_pred = ir_builder.eqExpr(shift_init, vector_size); - auto shift_val = - ir_builder.whereExpr(shift_pred, ir_builder.zeroVal(), shift_init); + auto shift_pred = IrBuilder::eqExpr(shift_init, vector_size); + auto shift_val = IrBuilder::whereExpr( + shift_pred, GpuLower::current()->kernel()->zeroVal(), shift_init); // >>>>>>>>>>>>> auto shift = createNamedScalarFromValue( @@ -183,13 +157,13 @@ class MisalignedVectorizationModifier { // remainder = (extent - shift) % vector_size // The number of elements remaining not accessed by vectorized operations - auto remaining_extent = ir_builder.subExpr(extent, shift); - auto remainder_val = ir_builder.modExpr(remaining_extent, vector_size); + auto remaining_extent = IrBuilder::subExpr(extent, shift); + auto remainder_val = IrBuilder::modExpr(remaining_extent, vector_size); auto remainder = createNamedScalarFromValue( parent_scope_ite->thenBody(), remainder_val, "remainder"); // (extent - remainder) is the upper-bound for the vectorize section - auto extent_remainder_val = ir_builder.subExpr(extent, remainder); + auto extent_remainder_val = IrBuilder::subExpr(extent, remainder); // >>>>>>>>>>>>> auto extent_minus_remainder = createNamedScalarFromValue( @@ -203,7 +177,7 @@ class MisalignedVectorizationModifier { // >>>>>>>>>>>>> auto last_root_domain_index_shift = - ir_builder.addExpr(last_root_domain_index, shift); + IrBuilder::addExpr(last_root_domain_index, shift); return { vector_size, @@ -220,20 +194,18 @@ class MisalignedVectorizationModifier { kir::IfThenElse* createVectorizeSection( const std::vector& child_loops, const VectorizeData& params) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto vectorized_child_loops = cloneForLoops( child_loops, params.vector_size, nullptr, true, params.shift); // Vectorize Range: [shift - (extent-remainder)) // (last_root_domain_index + shift) < (extent - remainder) - kir::Val* vectorize_cond = ir_builder.ltExpr( + Val* vectorize_cond = IrBuilder::ltExpr( params.last_root_domain_index_shift, params.extent_minus_remainder); kir::Predicate* vectorize_pred = - ir_builder.create(vectorize_cond->as()); + IrBuilder::create(vectorize_cond->as()); kir::IfThenElse* vectorize_ite = - ir_builder.create(vectorize_pred); + IrBuilder::create(vectorize_pred); for (auto cloned_loop : vectorized_child_loops) { vectorize_ite->thenBody().push_back(cloned_loop); @@ -247,20 +219,19 @@ class MisalignedVectorizationModifier { kir::IfThenElse* createInitialSection( const std::vector& child_loops, const VectorizeData& params) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto pre_child_loops = cloneForLoops( child_loops, params.vector_size, params.shift, false, nullptr); // Initial Range: [0 - shift) // last_root_domain_index == 0 - kir::Val* initial_cond = - ir_builder.eqExpr(params.last_root_domain_index, ir_builder.zeroVal()); + Val* initial_cond = IrBuilder::eqExpr( + params.last_root_domain_index, + GpuLower::current()->kernel()->zeroVal()); kir::Predicate* initial_pred = - ir_builder.create(initial_cond->as()); + IrBuilder::create(initial_cond->as()); kir::IfThenElse* initial_ite = - ir_builder.create(initial_pred); + IrBuilder::create(initial_pred); for (auto cloned_loop : pre_child_loops) { initial_ite->thenBody().push_back(cloned_loop); @@ -274,23 +245,21 @@ class MisalignedVectorizationModifier { kir::IfThenElse* createRemainderSection( const std::vector& child_loops, const VectorizeData& params) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto post_child_loops = cloneForLoops( child_loops, params.vector_size, params.remainder, false, params.shift); // Remainder Range: [(extent-remainder) - extent) // (extent - remainder) <= last_root_domain_index + shift < extent - kir::Val* lower_bound = ir_builder.geExpr( + Val* lower_bound = IrBuilder::geExpr( params.last_root_domain_index_shift, params.extent_minus_remainder); - kir::Val* upper_bound = - ir_builder.ltExpr(params.last_root_domain_index_shift, params.extent); - kir::Val* remainder_cond = ir_builder.andExpr(lower_bound, upper_bound); + Val* upper_bound = + IrBuilder::ltExpr(params.last_root_domain_index_shift, params.extent); + Val* remainder_cond = IrBuilder::andExpr(lower_bound, upper_bound); kir::Predicate* remainder_pred = - ir_builder.create(remainder_cond->as()); + IrBuilder::create(remainder_cond->as()); kir::IfThenElse* remainder_ite = - ir_builder.create(remainder_pred); + IrBuilder::create(remainder_pred); for (auto cloned_loop : post_child_loops) { remainder_ite->thenBody().push_back(cloned_loop); @@ -302,8 +271,6 @@ class MisalignedVectorizationModifier { kir::ForLoop* handleMisalignedVectorize( std::vector for_loop_structure, const kir::ForLoop* parent_for_loop) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto child_loops = findChildForLoops(parent_for_loop); // Assumption: All vectorize operations have the same shift @@ -315,17 +282,19 @@ class MisalignedVectorizationModifier { // The parent_for_loop contains allocate, read, compute, write operations const auto new_parent_for_loop = - ir_builder.create(parent_for_loop); + IrBuilder::create(parent_for_loop); // Transfer all expressions except for-loops to new parent for-loop // All expressions are placed at the beginning of the new for-loop - moveExprsExceptForLoops(parent_for_loop, new_parent_for_loop); + copyExprsExceptForLoops(parent_for_loop, new_parent_for_loop); // Get the predicate for all but the last root domain - auto pred_except_last_root_domain = ir_builder.create( - PredicateType::Misaligned, vectorized_expr, ir_builder.trueVal()); + auto pred_except_last_root_domain = IrBuilder::create( + PredicateType::Misaligned, + vectorized_expr, + GpuLower::current()->kernel()->trueVal()); kir::IfThenElse* pred_ite = - ir_builder.create(pred_except_last_root_domain); + IrBuilder::create(pred_except_last_root_domain); new_parent_for_loop->body().push_back(pred_ite); auto constants = createVectorizeConstants( @@ -351,17 +320,17 @@ class MisalignedVectorizationModifier { // Determine that the expression is UnaryOpType::Set AND // the output TensorView domain is vectorized - bool isVectorizeSetOp(kir::ForLoop* fl, kir::Expr* expr) { - if (fl->iter_domain()->parallelType() != + bool isVectorizeSetOp(kir::ForLoop* fl, Expr* expr) { + if (fl->iter_domain()->getParallelType() != ParallelType::MisalignedVectorize) { return false; } - if (expr->isA()) { - auto unaryOp = expr->as(); - if (unaryOp->out()->isA()) { - auto out_tv = unaryOp->out()->as(); - return unaryOp->operation() == UnaryOpType::Set && + if (expr->isA()) { + auto unaryOp = expr->as(); + if (unaryOp->out()->isA()) { + auto out_tv = unaryOp->out()->as(); + return unaryOp->getUnaryOpType() == UnaryOpType::Set && out_tv->domain()->hasVectorize(); } } @@ -374,15 +343,14 @@ class MisalignedVectorizationModifier { // vectorize flag - Do not generate for loop header // shift value - Add shift to global indices generated within for loop std::vector cloneForLoops( - const std::vector& for_loops, - kir::Val* loop_stop, - kir::Val* pred_stop, + const std::vector& for_loops_, + Val* loop_stop, + Val* pred_stop, bool vectorize, - kir::Val* vectorize_shift) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); + Val* vectorize_shift) { std::vector cloned_for_loops; - for (auto fl : for_loops) { + for (auto fl : for_loops_) { auto first_expr = fl->body().exprs().front(); bool has_vectorize_op = isVectorizeSetOp(fl, first_expr); @@ -391,12 +359,12 @@ class MisalignedVectorizationModifier { TORCH_INTERNAL_ASSERT( !has_vectorize_op || fl->body().exprs().size() == 1); - const auto new_loop = ir_builder.create( + const auto new_loop = IrBuilder::create( fl->iter_domain(), fl->index(), - ir_builder.zeroVal(), + GpuLower::current()->kernel()->zeroVal(), loop_stop, - ir_builder.oneVal(), + GpuLower::current()->kernel()->oneVal(), vectorize && has_vectorize_op, vectorize_shift, fl->isUnrollRequired()); @@ -406,9 +374,9 @@ class MisalignedVectorizationModifier { // Predicate the loop body if pred_stop is not null. This is to // make sure the loop itself is completely unrollable. if (pred_stop != nullptr) { - auto body_pred = ir_builder.create( - ir_builder.ltExpr(new_loop->index(), pred_stop)->as()); - auto body_ite = ir_builder.create(body_pred); + auto body_pred = IrBuilder::create( + IrBuilder::ltExpr(new_loop->index(), pred_stop)->as()); + auto body_ite = IrBuilder::create(body_pred); body->push_back(body_ite); body = &body_ite->thenBody(); } @@ -423,7 +391,7 @@ class MisalignedVectorizationModifier { } // Add all expressions except for loops to new parent for loop - void moveExprsExceptForLoops( + void copyExprsExceptForLoops( const kir::ForLoop* for_loop, kir::ForLoop* new_loop) { std::vector loops; @@ -448,10 +416,10 @@ class MisalignedVectorizationModifier { // Find the first vectorize set - either read or write // Add child For-Loop to for_loop_structure // Enable vectorize flag in child For-Loop - kir::Expr* findFirstVectorizedSetOp( + Expr* findFirstVectorizedSetOp( std::vector& for_loop_structure, - const std::vector& for_loops) { - for (auto fl : for_loops) { + const std::vector& for_loops_) { + for (auto fl : for_loops_) { auto first_expr = fl->body().exprs().front(); bool has_vectorize_op = isVectorizeSetOp(fl, first_expr); if (has_vectorize_op) { @@ -463,38 +431,31 @@ class MisalignedVectorizationModifier { } // Get full extent for the inner-most, merged root domain - kir::Val* getVectorizeExtent( - kir::TensorView* producer_tv, - kir::TensorView* consumer_tv) { + Val* getVectorizeExtent(TensorView* producer_tv, TensorView* consumer_tv) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - - auto consumer_fuser_tv = consumer_tv->fuserTv(); - auto producer_fuser_tv = producer_tv->fuserTv(); - auto p2c = - PairwiseRootDomainMap(producer_fuser_tv, consumer_fuser_tv) - .mapProducerToConsumer( - producer_fuser_tv->domain(), consumer_fuser_tv->domain()); + auto p2c = PairwiseRootDomainMap(producer_tv, consumer_tv) + .mapProducerToConsumer( + producer_tv->domain(), consumer_tv->domain()); auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo( - {consumer_fuser_tv->domain()->domain().begin() + - consumer_fuser_tv->getComputeAtPosition(), - consumer_fuser_tv->domain()->domain().end()}); + {consumer_tv->domain()->domain().begin() + + consumer_tv->getComputeAtPosition(), + consumer_tv->domain()->domain().end()}); auto producer_root_right_of_ca_domains = IterVisitor::getInputsTo( - {producer_fuser_tv->domain()->domain().begin() + - producer_fuser_tv->getComputeAtPosition(), - producer_fuser_tv->domain()->domain().end()}); + {producer_tv->domain()->domain().begin() + + producer_tv->getComputeAtPosition(), + producer_tv->domain()->domain().end()}); - const auto& consumer_contig = consumer_fuser_tv->domain()->contiguity(); - const auto& producer_contig = producer_fuser_tv->domain()->contiguity(); + const auto& consumer_contig = consumer_tv->domain()->contiguity(); + const auto& producer_contig = producer_tv->domain()->contiguity(); - auto producer_root_domain = producer_fuser_tv->getMaybeRFactorDomain(); + auto producer_root_domain = producer_tv->getMaybeRFactorDomain(); // Calculate extent of merged root domains - kir::Val* extent = nullptr; + Val* extent = nullptr; auto consumer_root_idx = - int(consumer_fuser_tv->getMaybeRFactorDomain().size()) - 1; + int(consumer_tv->getMaybeRFactorDomain().size()) - 1; for (int i = int(producer_root_domain.size()) - 1; i >= 0; --i) { auto producer_root_id = producer_root_domain.at(i); @@ -533,11 +494,10 @@ class MisalignedVectorizationModifier { // We now know it's safe to extend the vectorization domain to these // axes. It shouldn't matter whether producer or consumer is used. - auto consumer_extent = gpu_lower->lowerValue(consumer_root_id->extent()); if (extent == nullptr) { - extent = consumer_extent; + extent = consumer_root_id->extent(); } else { - extent = ir_builder.mulExpr(extent, consumer_extent); + extent = IrBuilder::mulExpr(extent, consumer_root_id->extent()); } // If it's not contiguous, extending the vectorization domain @@ -554,57 +514,37 @@ class MisalignedVectorizationModifier { return extent; } - kir::Val* createNamedScalarFromValue( + Val* createNamedScalarFromValue( kir::Scope& body, - kir::Val* val, + Val* val, const std::string& name, bool address = false) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto namedScalar = (address) ? ir_builder.addressExprNamedScalar(name, val) - : ir_builder.setExprNamedScalar(name, val); + auto namedScalar = (address) ? IrBuilder::addressExprNamedScalar(name, val) + : IrBuilder::setExprNamedScalar(name, val); TORCH_INTERNAL_ASSERT(namedScalar->definition() != nullptr); - auto alloc = ir_builder.create( - namedScalar, MemoryType::Local, ir_builder.oneVal()); + auto alloc = IrBuilder::create( + namedScalar, + MemoryType::Local, + GpuLower::current()->kernel()->oneVal()); body.push_back(alloc); body.push_back(namedScalar->definition()); return namedScalar; } - - private: - // We will track which loops in the incoming IR will be replaced and by what - std::unordered_map expr_replacement_map_; - - // A depth-first ordering of nested for loops - // It is used for indexing and predicate generation - std::vector for_loops_structure_; }; } // namespace -std::vector processMisalignedVectorization( - Fusion* fusion, - const std::vector& exprs) { - FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization"); - - MisalignedVectorizationModifier mvm; - mvm.process(exprs); - - std::vector mutated_exprs; - mutated_exprs.reserve(exprs.size()); - for (auto expr : exprs) { - mutated_exprs.push_back( - ir_utils::applyReplacements(mvm.replacementMap(), expr)); - } - - return mutated_exprs; +std::vector processMisalignedVectorization( + const std::vector& exprs) { + return MisalignedVectorizationModifier::processMisalignedVectorization(exprs); } bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl) { for (auto expr : fl->body().exprs()) { if (expr->isA()) { auto child_fl = expr->as(); - if (child_fl->iter_domain()->parallelType() == + if (child_fl->iter_domain()->getParallelType() == ParallelType::MisalignedVectorize) { return true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h index 588d3787752..bd7ae19d93a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h @@ -1,5 +1,5 @@ #pragma once -#include +#include #include @@ -106,9 +106,8 @@ namespace cuda { //! } //! } //! -std::vector processMisalignedVectorization( - Fusion* fusion, - const std::vector& exprs); +std::vector processMisalignedVectorization( + const std::vector& exprs); bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 838d5d85d9e..cd34c56b510 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -7,8 +7,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -23,27 +22,26 @@ namespace cuda { namespace { -class ConditionalFromPredicateModifier { +class ConditionalFromPredicateModifier : public kir::IrVisitor { public: - ConditionalFromPredicateModifier(const std::vector& exprs) { + ConditionalFromPredicateModifier() = delete; + + static std::vector fillPredicates(const std::vector& exprs) { + ConditionalFromPredicateModifier cfpm(exprs); + return cfpm.exprs_; + } + + private: + ConditionalFromPredicateModifier(const std::vector& exprs) { FUSER_PERF_SCOPE( "GpuLower::Lower::ConditionalFromPredicateModifier::process"); - for (auto* expr : exprs) { - handle(expr); - } + kir::IrVisitor::handle(exprs); } - const std::unordered_map& replacementMap() const { - return expr_replacement_map_; - } + using kir::IrVisitor::handle; - private: - void handle(kir::Expr* expr) { - if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - } else if (expr != nullptr && expr->predicate() != nullptr) { + void handle(Expr* expr) final { + if (expr != nullptr && expr->predicate() != nullptr) { // Replace expr predicate with bool conditional auto conditional = generateConditional(expr->predicate()); TORCH_INTERNAL_ASSERT(conditional != nullptr); @@ -51,9 +49,11 @@ class ConditionalFromPredicateModifier { TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr); setWritePredicate(expr, conditional); } + + kir::IrVisitor::handle(expr); } - void setWritePredicate(kir::Expr* expr, kir::Bool* read_cond) { + void setWritePredicate(Expr* expr, Bool* read_cond) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); if (write_cond) { @@ -66,46 +66,25 @@ class ConditionalFromPredicateModifier { } } - void handle(kir::ForLoop* fl) { - for_loops_structure_.push_back(fl); - - const auto exprs_copy = fl->body().exprs(); - for (auto expr : exprs_copy) { - handle(expr); - } - - for_loops_structure_.pop_back(); - } - - void handle(kir::IfThenElse* ite) { + void handle(kir::IfThenElse* ite) final { TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr); // If ite already has Bool conditional, handle internal expressions // Otherwise, generate conditional and update predicate - if (ite->predicate()->hasValue()) { - const auto then_exprs_copy = ite->thenBody().exprs(); - for (auto expr : then_exprs_copy) { - handle(expr); - } - - const auto else_exprs_copy = ite->elseBody().exprs(); - for (auto expr : else_exprs_copy) { - handle(expr); - } - } else { + if (!ite->predicate()->hasValue()) { auto conditional = generateConditional(ite->predicate()); TORCH_INTERNAL_ASSERT(conditional != nullptr); - TORCH_INTERNAL_ASSERT(conditional->isA()); + TORCH_INTERNAL_ASSERT(conditional->isA()); // Update bool conditional in-place ite->predicate()->setValue(conditional); - handle(ite); TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr); } + kir::IrVisitor::handle(ite); } // Generate conditional according to PredicateType - kir::Bool* generateConditional(kir::Predicate* pred) { + Bool* generateConditional(kir::Predicate* pred) { switch (pred->predicate_type()) { case PredicateType::Inline: case PredicateType::ReductionWrite: @@ -114,15 +93,16 @@ class ConditionalFromPredicateModifier { case PredicateType::Padding: { return PredicateCompute::getInlinePredicate( pred->expr(), - for_loops_structure_, + for_loops_, pred->thread_pred(), pred->predicate_type()); } case PredicateType::Vectorize: { std::vector outer_loops; kir::ForLoop* vectorized_loop = nullptr; - for (auto loop : for_loops_structure_) { - if (loop->iter_domain()->parallelType() == ParallelType::Vectorize) { + for (auto loop : for_loops_) { + if (loop->iter_domain()->getParallelType() == + ParallelType::Vectorize) { vectorized_loop = loop; break; } else { @@ -134,8 +114,7 @@ class ConditionalFromPredicateModifier { return UnswitchPredicate::get(outer_loops, vectorized_loop); } case PredicateType::Unswitch: { - return UnswitchPredicate::get( - for_loops_structure_, pred->unrolled_loop()); + return UnswitchPredicate::get(for_loops_, pred->unrolled_loop()); } case PredicateType::Manual: { return pred->value(); @@ -145,33 +124,13 @@ class ConditionalFromPredicateModifier { } return nullptr; } - - private: - // We will track which loops in the incoming IR will be replaced and by what - std::unordered_map expr_replacement_map_; - - // A depth-first ordering of nested for loops - // It is used for indexing and predicate generation - std::vector for_loops_structure_; }; } // namespace -std::vector generateConditionalFromPredicate( - Fusion* fusion, - const std::vector& exprs) { - FUSER_PERF_SCOPE("GpuLower::Lower::generateConditionalFromPredicate"); - - ConditionalFromPredicateModifier p2cm(exprs); - - std::vector mutated_exprs; - mutated_exprs.reserve(exprs.size()); - for (auto expr : exprs) { - mutated_exprs.push_back( - ir_utils::applyReplacements(p2cm.replacementMap(), expr)); - } - - return mutated_exprs; +std::vector generateConditionalFromPredicate( + const std::vector& exprs) { + return ConditionalFromPredicateModifier::fillPredicates(exprs); } namespace { @@ -225,17 +184,14 @@ class PredicateAnalyzer : public OptOutDispatch { return needs_predicate_; } - using OptOutDispatch::handle; - void handle(IterDomain* consumer_id) override { // The traversal should have ended if needs_predicate_ was true TORCH_INTERNAL_ASSERT(!needs_predicate_); // If consumer_id is not going to be materialized as a loop (e.g., // broadcast), no need to predicate - const auto gpu_lower = GpuLower::current(); if (consumer_id->isBroadcast() || - gpu_lower->trivialReductionInfo().isDerived(consumer_id)) { + GpuLower::current()->trivialReductionInfo().isDerived(consumer_id)) { return; } @@ -250,7 +206,7 @@ class PredicateAnalyzer : public OptOutDispatch { return; } - handle(consumer_id->definition()); + OptOutDispatch::handle(consumer_id->definition()); } // If it splits the input axis evenly, proceeds to check the input @@ -291,7 +247,7 @@ class PredicateAnalyzer : public OptOutDispatch { } // namespace bool PredicateElimination::needsPredicate(Expr* expr) const { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } @@ -394,7 +350,7 @@ bool PredicateElimination::needsPredicate(Expr* expr) const { } void PredicateElimination::handle(Expr* expr) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return; } @@ -491,7 +447,7 @@ bool PredicateElimination::setReductionInitValue( bool PredicateElimination::canOmitPredicate(const Expr* expr) const { TORCH_INTERNAL_ASSERT(expr != nullptr); - const auto out_tv = ir_utils::getTVOutput(expr); + const auto out_tv = ir_utils::getTvOutput(expr); TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression"); // No need to predicate local tensors to which a scalar is assigned if (out_tv->getMemoryType() == MemoryType::Local) { @@ -508,38 +464,17 @@ bool PredicateElimination::canOmitPredicate(const Expr* expr) const { return false; } -bool PredicateElimination::canOmitPredicate(const kir::Expr* kir_expr) const { - TORCH_INTERNAL_ASSERT(kir_expr != nullptr); - const auto out_tv = ir_utils::getTVOutput(kir_expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression"); - // No need to predicate local tensors to which a scalar is assigned - if (out_tv->memoryType() == MemoryType::Local) { - if (auto uop = dynamic_cast(kir_expr)) { - if (uop->operation() == UnaryOpType::Set && uop->in()->isScalar()) { - return true; - } - } - } - const auto fuser_tv = out_tv->fuserTv(); - if (fuser_tv == nullptr) { - return false; - } - return canOmitPredicate(fuser_tv->definition()); -} - -kir::Val* PredicateElimination::getInitValue(TensorView* tv) const { +Val* PredicateElimination::getInitValue(TensorView* tv) const { auto it = init_value_map_.find(tv); if (it == init_value_map_.end()) { return nullptr; } - const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); auto init_val = it->second; if (init_val == nullptr) { // No reduction restriction. Just use zero - return ir_builder.zeroVal(); + return GpuLower::current()->kernel()->zeroVal(); } else { - return gpu_lower->lowerValue(init_val); + return init_val; } } diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.h b/torch/csrc/jit/codegen/cuda/lower_predicate.h index 393d0fa5c18..c0a1f702f7b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.h +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.h @@ -1,5 +1,5 @@ #pragma once -#include +#include #include #include @@ -13,9 +13,8 @@ namespace cuda { //! Update predicates with valid bool conditionals //! -std::vector generateConditionalFromPredicate( - Fusion* fusion, - const std::vector& exprs); +std::vector generateConditionalFromPredicate( + const std::vector& exprs); class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { public: @@ -26,13 +25,8 @@ class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { //! \param expr Tensor expression bool canOmitPredicate(const Expr* expr) const; - //! True if expr does not need a predicate - //! - //! \param expr KIR tensor expr - bool canOmitPredicate(const kir::Expr* expr) const; - //! Value to initialize out-of-bound regions - kir::Val* getInitValue(TensorView* tv) const; + Val* getInitValue(TensorView* tv) const; //! Dump to string for debugging std::string toString() const; @@ -40,7 +34,7 @@ class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { private: using IterVisitor::handle; - void handle(Expr* expr) override; + void handle(Expr* expr) final; //! Set a value to initialize out-of-bound regions bool setDefaultInitValue(TensorView* tv); diff --git a/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp new file mode 100644 index 00000000000..582b6d91d06 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_replace_size.cpp @@ -0,0 +1,288 @@ +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { +// Going to generate a map of tensor view root domain extents to reduce the +// number used during lowering. For example if we have: +// +// T2[i0, i1] = T1[i0, i1] + T2[i2, i3] +// +// We know it would be safe to use: +// +// T2[i0, i1] = T1[i0, i1] + T2[i0, i1] +// +// And that way we don't generate T2.size[0] and T2.size[1], instead we will +// reuse T1.size[0] and T1.size[1] +// This is important when doing CSE as T2 and T1 would otherwise look like +// they're using different values, even though we know they're the same +// +// There's some duplicate logic here that's in computeAt map, but it's not so +// concice there to pull out. May want to consider making this mapping its own +// class especially as it may be useful during scheduling. +std::unordered_map getSimplificationMap(Fusion* fusion) { + std::list> disjoint_root_sets; + std::unordered_map*> + id_to_disjoint_root_set; + + auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set]( + IterDomain* id0, IterDomain* id1) { + if (id0->isBroadcast() || id1->isBroadcast()) { + return; + } + + auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0); + auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1); + bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end(); + bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end(); + + if (set_0_found && set_1_found) { + if (disjoint_set_0_it->second == disjoint_set_1_it->second) { + return; + } + // merge second disjoint set into first + auto* set_0 = disjoint_set_0_it->second; + auto* set_1 = disjoint_set_1_it->second; + for (auto id : *set_1) { + set_0->emplace(id); + id_to_disjoint_root_set[id] = set_0; + } + // remove second set from disjoint_root_sets + disjoint_root_sets.erase(std::find( + disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1)); + } else if (set_0_found || set_1_found) { + auto existing_set = + set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second; + auto to_add_id = set_0_found ? id1 : id0; + existing_set->emplace(to_add_id); + id_to_disjoint_root_set[to_add_id] = existing_set; + // add entry into existing set + } else { + // create new set entry + disjoint_root_sets.emplace_back(std::unordered_set()); + auto* new_set = &disjoint_root_sets.back(); + new_set->emplace(id0); + new_set->emplace(id1); + id_to_disjoint_root_set[id0] = new_set; + id_to_disjoint_root_set[id1] = new_set; + } + }; + + auto fusion_vals = fusion->usedMathVals(); + for (auto producer_tv : ir_utils::filterByType(fusion_vals)) { + auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv); + for (auto consumer_tv : consumer_tvs) { + auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); + auto c2p_root_map = pairwise_map.mapConsumerToProducer( + consumer_tv->domain(), producer_tv->domain()); + for (auto entry : c2p_root_map) { + auto c_id = entry.first; + auto p_id = entry.second; + map_root_ids(p_id, c_id); + } + } + } + + // Map each set to an input ID (if it exists) that has the smallest ->name() + // entry value + std::unordered_map*, IterDomain*> + set_to_input_id; + + // Loop over the root domains, of the inputs to the fusion. Pick an input ID + // to use as the representative ID of the collected sets. Only consider inputs + // as those are the ones that map to values like "T0.size[1]". They are he + // ID's that propagated their extents into the problem. We could also check + // the outputs as we do have C++ examples of using output dimensions for the + // problem size instead of inputs. However, we don't do anything where we can + // translate to those kinds of kernels integrated into PyTorch. + for (auto input_tv : ir_utils::filterByType(fusion->inputs())) { + for (auto id : + TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) { + auto id_set_it = id_to_disjoint_root_set.find(id); + if (id_set_it == id_to_disjoint_root_set.end()) { + continue; + } + auto* id_set = id_set_it->second; + if (set_to_input_id.find(id_set) == set_to_input_id.end()) { + set_to_input_id[id_set] = id; + } else { + auto input_id_of_set = set_to_input_id.at(id_set); + // Swap id's if new name is less than previously set + bool swap_ids = id->name() < input_id_of_set->name(); + // If new id is a const scalar but previously was'nt use the const + // scalar + swap_ids = swap_ids || + (id->extent()->isConstScalar() && + !input_id_of_set->extent()->isConstScalar()); + // If previous scalar was const and new isn't, don't swap + swap_ids = swap_ids && + !(input_id_of_set->extent()->isConstScalar() && + !id->extent()->isConstScalar()); + + if (swap_ids) { + set_to_input_id[id_set] = id; + } + } + } + } + + // Finally make map from ID extents to the representitive ID extent. + std::unordered_map extent_to_min_input_id_extent; + for (auto entry : set_to_input_id) { + auto* set = entry.first; + auto input_id = entry.second; + for (auto id : *set) { + extent_to_min_input_id_extent[id->extent()] = input_id->extent(); + } + } + return extent_to_min_input_id_extent; +} + +std::vector allLeafOuts(Fusion* fusion) { + auto exprs = StmtSort::getExprs(fusion, true); + std::unordered_set inputs; + std::unordered_set outputs; + std::vector ordered_outputs; + for (auto expr : exprs) { + inputs.insert(expr->inputs().begin(), expr->inputs().end()); + outputs.insert(expr->outputs().begin(), expr->outputs().end()); + ordered_outputs.insert( + ordered_outputs.end(), expr->outputs().begin(), expr->outputs().end()); + } + for (auto input : inputs) { + outputs.erase(input); + } + + std::vector ordered_leaf_outs; + for (auto out : ordered_outputs) { + if (outputs.find(out) != outputs.end()) { + ordered_leaf_outs.push_back(out); + } + } + return ordered_leaf_outs; +} + +class ValReplacementMutator : private OptOutMutator { + public: + ValReplacementMutator( + Fusion* fusion, + const std::unordered_map& replacement_map) + : replacement_map_(replacement_map) { + FusionGuard fg(fusion); + + // Welford makes this a little annoying since it holds a count which is + // typically not used by anything else. If we don't grab that count, then it + // would be a tensorview that doesn't get updated extents. Therefore, first + // grab all leaves towards outputs and grab stmts from there. + auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true); + for (auto stmt : stmts) { + mutate(stmt); + } + } + + private: + using OptOutMutator::mutate; + void mutate(Val* val) final { + if (replacement_map_.find(val) == replacement_map_.end()) { + return OptOutMutator::mutate(val); + } + auto replaced_val = replacement_map_.at(val); + registerMutation(val, replaced_val); + } + + const std::unordered_map& replacement_map_; +}; + +} // namespace + +void replaceSymbolicSizes(Fusion* fusion) { + FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes"); + std::unordered_map tensor_dim_map; + + // Grab inputs and outputs + std::vector inputs_and_outputs; + for (auto val : fusion->inputs()) { + if (ir_utils::isTV(val)) { + inputs_and_outputs.push_back(val->as()); + } + } + // Symbolic size is necessary for outputs if there are no inputs. + // Otherwise infer output sizes from the inputs via expression evaluation. + if (fusion->inputs().empty()) { + for (auto val : fusion->outputs()) { + if (ir_utils::isTV(val)) { + inputs_and_outputs.push_back(val->as()); + } + } + } + + // Generate map for all tensorview root domain values to map them to symbolic + // values. i.e. T0->getRootDomain()[0] would map to a named scalar + // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir. + for (TensorView* tv : inputs_and_outputs) { + // Replace the domain with one based on Ti.size[j] + const std::vector& root_td = tv->getRootDomain(); + + size_t dim = 0; + for (auto id : root_td) { + Val* orig_size = id->extent(); + + // Output sizes could have reduction axes, which isn't what gets output. + // NOLINTNEXTLINE(bugprone-branch-clone) + if (id->isReduction() || + (id->getIterType() == IterType::BroadcastWithoutStride)) { + continue; + } else if ( + id->isRFactorProduct() || + // NOLINTNEXTLINE(bugprone-branch-clone) + (id->getIterType() == IterType::BroadcastWithStride) || + orig_size->isConstScalar()) { + dim++; + continue; + } + + // Currently turn off this part for inputs of segmented fusion, + // since FusionKernelRuntime will provide these as integer inputs + if (tensor_dim_map.find(orig_size) == tensor_dim_map.end() && + !orig_size->isFusionInput() && !orig_size->isConstScalar()) { + std::stringstream ss; + ss << "T" << tv->name() << ".size[" << dim++ << "]"; + tensor_dim_map[orig_size] = IrBuilder::create( + ss.str(), orig_size->getDataType().value()); + } else { + dim++; + } + } + } + + // Use a minimal number of sizes from provided tensors. + auto extent_simplification_map = getSimplificationMap(fusion); + for (auto extent_entry : extent_simplification_map) { + auto orig_extent = extent_entry.first; + auto simplified_extent = extent_entry.second; + if (tensor_dim_map.count(orig_extent)) { + if (tensor_dim_map.count(simplified_extent)) { + tensor_dim_map[orig_extent] = tensor_dim_map[simplified_extent]; + } else { + tensor_dim_map[orig_extent] = simplified_extent; + } + } + } + + // Run mutation on the fusion with the tensor_dim_map + ValReplacementMutator(fusion, tensor_dim_map); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_replace_size.h b/torch/csrc/jit/codegen/cuda/lower_replace_size.h new file mode 100644 index 00000000000..81cee9f6ffe --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_replace_size.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// TensorViews are all based on symbolic sizes. When we first initialize them +// we don't know if they're inputs or outputs which would mean that they have +// runtime shapes. Intermediate tensors (those not going to global memory) do +// not have this information. Since we need to have the correct information in +// the kernel being fetched for shapes, we want to replace input and output +// tensors to reference the runtime structure containing sizes. +void replaceSymbolicSizes(Fusion*); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 8a4f6980e01..ca451ee5f97 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include #include #include @@ -19,19 +17,17 @@ namespace fuser { namespace cuda { void ShiftPredicateInserter::insert( - kir::Expr* expr, + Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, bool within_unswitch) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - kir::TensorView* out_tv = ir_utils::getTVOutput(expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output"); + TensorView* out_tv = ir_utils::getTvOutput(expr); + TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); - TensorView* out_fuser_tv = out_tv->fuserTv(); const bool needs_shift_predicate = - gpu_lower->haloInfo().needsShiftPredicate(out_fuser_tv->definition()); + gpu_lower->haloInfo().needsShiftPredicate(out_tv->definition()); if (!needs_shift_predicate) { return; } @@ -48,12 +44,12 @@ void ShiftPredicateInserter::insert( kir::Predicate* thread_pred_expr = nullptr; if (within_unswitch) { - thread_pred_expr = ir_builder.create(thread_pred); + thread_pred_expr = IrBuilder::create(thread_pred); } kir::Predicate* shift_pred = within_unswitch ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::Shift, expr, thread_pred); // If the expr involves a thread-block barrier, set the predicate of @@ -64,7 +60,7 @@ void ShiftPredicateInserter::insert( return; } - auto shift_ite = ir_builder.create(shift_pred); + auto shift_ite = IrBuilder::create(shift_pred); auto& scope = loops.back()->body(); @@ -83,56 +79,33 @@ void ShiftPredicateInserter::insert( } // Padding by zero - kir::Predicate* padding_pred = ir_builder.create( + kir::Predicate* padding_pred = IrBuilder::create( PredicateType::Padding, expr, thread_pred); - auto bounds_ite = ir_builder.create(padding_pred); + auto bounds_ite = IrBuilder::create(padding_pred); const int pad_value = 0; - auto pad_expr = ir_builder.create( - UnaryOpType::Set, out_tv, ir_builder.create(pad_value)); + auto pad_expr = IrBuilder::create( + UnaryOpType::Set, out_tv, IrBuilder::create(pad_value)); bounds_ite->thenBody().push_back(pad_expr); // Insert the else block shift_ite->elseBody().push_back(bounds_ite); } -AxisHaloInfo::AxisHaloInfo() { - auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - setWidth(0, ir_builder.zeroVal()); - setWidth(1, ir_builder.zeroVal()); -} - -kir::Int* AxisHaloInfo::width() const { - auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - return ir_builder.addExpr(width(0), width(1))->as(); +int AxisHaloInfo::width() const { + return width(0) + width(1); } -kir::Int* AxisHaloInfo::width(int pos) const { +int AxisHaloInfo::width(int pos) const { TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2); - TORCH_INTERNAL_ASSERT(widths_[pos] != nullptr); return widths_[pos]; } -void AxisHaloInfo::setWidth(int pos, kir::Int* width) { +void AxisHaloInfo::setWidth(int pos, int width) { TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2); widths_[pos] = width; } -void AxisHaloInfo::merge(int pos, kir::Int* other) { - auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto cur = width(pos); - kir::Int* new_width = nullptr; - if (cur->isConst() && other->isConst()) { - new_width = ir_builder.create( - std::max(cur->value().value(), other->value().value())); - } else if (cur->isZeroInt()) { - new_width = other; - } else if (other->isZeroInt()) { - new_width = cur; - } else { - new_width = ir_builder.maxExpr(width(pos), other)->as(); - } +void AxisHaloInfo::merge(int pos, int other) { + auto new_width = std::max(width(pos), other); setWidth(pos, new_width); } @@ -144,13 +117,12 @@ void AxisHaloInfo::merge(const AxisHaloInfo& other) { bool AxisHaloInfo::hasHalo() const { return std::any_of( - widths_.begin(), widths_.end(), [](auto w) { return !w->isZeroInt(); }); + widths_.begin(), widths_.end(), [](auto w) { return w != 0; }); } std::string AxisHaloInfo::toString() const { std::stringstream ss; - ss << "<" << kir::toString(width(0)) << ", " << kir::toString(width(1)) - << ">"; + ss << "<" << width(0) << ", " << width(1) << ">"; return ss.str(); } @@ -158,38 +130,21 @@ bool HaloInfo::hasRootAxisInfo(IterDomain* id) const { return root_axis_map_.find(id) != root_axis_map_.end(); } -bool HaloInfo::hasRootAxisInfo(kir::IterDomain* id) const { - return kir_root_axis_map_.find(id) != kir_root_axis_map_.end(); -} - const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const { + // TODO: Enable this check, was failing in many tests + // TORCH_INTERNAL_ASSERT( + // id->definition() == nullptr || id->isRFactorProduct(), + // "Invalid IterDomain: ", + // id); auto it = root_axis_map_.find(id); TORCH_INTERNAL_ASSERT( - it != root_axis_map_.end(), "Halo root axis info not found for ", id); - return it->second; -} - -AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - return const_cast( - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(this)->getRootAxisInfo(id)); -} - -const AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) const { - TORCH_INTERNAL_ASSERT( - id->definition() == nullptr || id->isRFactorProduct(), - "Invalid IterDomain: ", - id); - auto it = kir_root_axis_map_.find(id); - TORCH_INTERNAL_ASSERT( - it != kir_root_axis_map_.end(), + it != root_axis_map_.end(), "Halo root axis info not found for ", - kir::toString(id)); + id->toString()); return it->second; } -AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) { +AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) return const_cast( // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -200,9 +155,6 @@ void HaloInfo::setRootAxisInfo( IterDomain* id, const AxisHaloInfo& root_axis_info) { root_axis_map_[id] = root_axis_info; - kir_root_axis_map_ - [GpuLower::current()->lowerValue(id)->as()] = - root_axis_info; initializeFromRootAxisInfo(id); return; @@ -283,9 +235,6 @@ void HaloInfo::propagateRootAxisInfo( const auto& c_root = consumer->getRootDomain(); - auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - for (const auto i : c10::irange(c_root.size())) { auto c_id = c_root[i]; auto it = c2p.find(c_id); @@ -332,31 +281,19 @@ void HaloInfo::propagateRootAxisInfo( p_info.merge(c_info); } else { int pos = (offset > 0) ? 0 : 1; - p_info.merge( - pos, - ir_builder.addExpr(c_info.width(pos), std::abs(offset)) - ->as()); + p_info.merge(pos, c_info.width(pos) + std::abs(offset)); } } else if (auto gather_op = dynamic_cast(expr)) { - const auto window_dim = - gpu_lower->lowerValue(gather_op->windowShape()[i]); - if (window_dim->isOneInt()) { + const auto window_dim = gather_op->windowShape()[i]; + if (window_dim == 1) { p_info.merge(c_info); continue; } - const auto& pad_dim = gather_op->padWidth()[i]; - const auto pad_dim0 = gpu_lower->lowerValue(pad_dim[0])->as(); - p_info.merge( - 0, ir_builder.addExpr(c_info.width(0), pad_dim0)->as()); + const auto pad_dim0 = gather_op->padWidth()[i][0]; + p_info.merge(0, c_info.width(0) + pad_dim0); // The right-side halo is propagated as: // consumer_right_halo + (window_dim - 1 - left_padding) - p_info.merge( - 1, - ir_builder - .subExpr( - ir_builder.addExpr(c_info.width(1), window_dim), - ir_builder.addExpr(pad_dim0, 1)) - ->as()); + p_info.merge(1, c_info.width(1) + window_dim - 1 - pad_dim0); } else { p_info.merge(c_info); } @@ -390,29 +327,30 @@ void HaloInfo::initializeFromRootAxisInfo(IterDomain* id) { TORCH_INTERNAL_ASSERT(hasRootAxisInfo(id)); auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); const auto& halo_info = getRootAxisInfo(id); auto halo_width = halo_info.width(); if (!halo_info.hasHalo()) { - halo_width_map_[id] = ir_builder.zeroVal(); + setHaloWidth(id, 0); return; } auto expanded_extent = - ir_builder.addExpr(gpu_lower->lowerValue(id->extent()), halo_width); - kir_extent_map_[gpu_lower->lowerValue(id)->as()] = - expanded_extent; + IrBuilder::addExpr(id->extent(), IrBuilder::create(halo_width)); + extent_map_[id] = expanded_extent; halo_width_map_[id] = halo_width; inheritance_map_[id] = {id}; } +void HaloInfo::setHaloWidth(IterDomain* id, int halo_width) { + halo_width_map_[id] = halo_width; +} + // Propagate extent information from root axes to descendants void HaloInfo::build(TensorDomain* td) { auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); auto exprs = DependencyCheck::getAllExprsBetween( {td->getMaybeRFactorDomain().begin(), td->getMaybeRFactorDomain().end()}, @@ -459,33 +397,29 @@ void HaloInfo::build(TensorDomain* td) { auto in_id = split->in(); - const auto& halo_width_it = halo_width_map_.find(in_id); - // If no halo info is found, nothing needs to be done. This ID // must be an ancestor of a domain set by setRootAxisInfo. - if (halo_width_it == halo_width_map_.end()) { + if (!hasHaloWidth(in_id)) { continue; } - const auto halo_width = halo_width_it->second; + const auto halo_width = getHaloWidth(in_id); - if (halo_width->isZeroInt()) { - halo_width_map_.insert({split->outer(), halo_width}); - halo_width_map_.insert({split->inner(), halo_width}); + if (halo_width == 0) { + setHaloWidth(split->outer(), 0); + setHaloWidth(split->inner(), 0); continue; } // propagate to inner domain auto out_id = split->inner(); - auto expanded_extent = ir_builder.addExpr( - gpu_lower->lowerValue(out_id->extent()), halo_width); - kir_extent_map_.insert( - {gpu_lower->lowerValue(out_id)->as(), - expanded_extent}); + auto expanded_extent = + SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width); + extent_map_.insert({out_id, expanded_extent}); - halo_width_map_.insert({split->outer(), ir_builder.zeroVal()}); - halo_width_map_.insert({split->inner(), halo_width}); + setHaloWidth(split->outer(), 0); + setHaloWidth(split->inner(), halo_width); insertToInheritanceMap(td, in_id, split->inner()); } else if (auto merge = dynamic_cast(expr)) { @@ -495,25 +429,24 @@ void HaloInfo::build(TensorDomain* td) { auto outer_extent = getExtent(merge->outer()); if (inner_extent != nullptr || outer_extent != nullptr) { if (inner_extent == nullptr) { - inner_extent = gpu_lower->lowerValue(merge->inner()->extent()); + inner_extent = merge->inner()->extent(); } else { insertToInheritanceMap(td, merge->inner(), merge->out()); } if (outer_extent == nullptr) { - outer_extent = gpu_lower->lowerValue(merge->outer()->extent()); + outer_extent = merge->outer()->extent(); } else { insertToInheritanceMap(td, merge->outer(), merge->out()); } - auto expanded_extent = ir_builder.mulExpr(outer_extent, inner_extent); - kir_extent_map_.insert( - {gpu_lower->lowerValue(merge->out())->as(), - expanded_extent}); + auto expanded_extent = + SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent); + extent_map_.insert({merge->out(), expanded_extent}); // Splitting the output of this merge is not allowed, so // remember it merged_shifted_ids.insert(merge->out()); // Note that halo_width_map_ is not updated } else { - halo_width_map_.insert({merge->out(), ir_builder.zeroVal()}); + setHaloWidth(merge->out(), 0); } } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); @@ -579,7 +512,7 @@ void HaloInfo::validate(TensorView* tv) const { bool shared_mem_needed = false; for (auto use : tv->uses()) { - if (!ir_utils::isTVOp(use)) { + if (!ir_utils::isTvOp(use)) { continue; } if (use->isA() || use->isA()) { @@ -629,21 +562,16 @@ void HaloInfo::validate(TensorView* tv) const { return; } -kir::Val* HaloInfo::getExtent(IterDomain* id) const { - auto kir_id = GpuLower::current()->lowerValue(id)->as(); - return getExtent(kir_id); -} - -kir::Val* HaloInfo::getExtent(kir::IterDomain* id) const { - auto it = kir_extent_map_.find(id); - if (it != kir_extent_map_.end()) { +Val* HaloInfo::getExtent(IterDomain* id) const { + auto it = extent_map_.find(id); + if (it != extent_map_.end()) { return it->second; } else { return nullptr; } } -kir::Int* HaloInfo::getHaloWidth(IterDomain* id) const { +int HaloInfo::getHaloWidth(IterDomain* id) const { auto it = halo_width_map_.find(id); TORCH_INTERNAL_ASSERT(it != halo_width_map_.end()); return it->second; @@ -736,63 +664,11 @@ bool extentCompare( } // namespace bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const { - auto cmp = [](kir::Int* x, kir::Int* y) { - if (x == y) { - return true; - } - auto xv = x->value(); - auto yv = y->value(); - return xv.has_value() && yv.has_value() && xv.value() <= yv.value(); - }; - return extentCompare(*this, id1, id2, cmp); + return extentCompare(*this, id1, id2, std::less_equal<>()); } bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const { - // Returns true only when x and y are proven to be the same. The - // analysis is not comprehensive and can prove in rather trivial - // cases only. Specifically: - // - x and y are the same pointers - // - Both have static values and they are the same - // - Both are defined by the same expression and the inputs are - // proven to be equal - std::function cmp = [&](kir::Int* x, - kir::Int* y) { - if (x == y) { - return true; - } - - auto xv = x->value(); - auto yv = y->value(); - if (xv.has_value() && yv.has_value() && xv.value() == yv.value()) { - return true; - } - - // Check if both are defined by an expression of the same type. If - // so, recursively check the input operands. - auto x_def = x->definition(); - auto y_def = y->definition(); - if (x_def && y_def && - ((x_def->isA() && y_def->isA() && - x_def->as()->operation() == - y_def->as()->operation()) || - (x_def->isA() && y_def->isA() && - x_def->as()->operation() == - y_def->as()->operation()))) { - for (const auto i : c10::irange(x_def->inputs().size())) { - auto x_input = dynamic_cast(x_def->inputs()[i]); - auto y_input = dynamic_cast(y_def->inputs()[i]); - // Both must be kir::Int - TORCH_INTERNAL_ASSERT(x_input && y_input); - if (!cmp(x_input, y_input)) { - return false; - } - } - return true; - } - - return false; - }; - return extentCompare(*this, id1, id2, cmp); + return extentCompare(*this, id1, id2, std::equal_to<>()); } std::string HaloInfo::toString() const { @@ -822,16 +698,19 @@ std::string HaloInfo::toString() const { } bool HaloInfo::needsShiftPredicate(Expr* expr) const { - auto consumer_td = ir_utils::getTVOutput(expr)->domain(); - auto shift_expr = dynamic_cast(expr); - auto gather_expr = dynamic_cast(expr); + // In lowering shift and gather turn into a unary op. We really need the shift + // expr. Do a round about trick to grab it: + auto tv_out = ir_utils::getTvOutput(expr); + auto consumer_td = tv_out->domain(); + auto shift_expr = dynamic_cast(tv_out->definition()); + auto gather_expr = dynamic_cast(tv_out->definition()); for (const auto i : c10::irange(consumer_td->getRootDomain().size())) { auto consumer_id = consumer_td->getRootDomain()[i]; const auto consumer_halo_info = getRootAxisInfo(consumer_id); if (consumer_halo_info.hasHalo() || (shift_expr != nullptr && shift_expr->offset(i) != 0 && !consumer_id->isBroadcast()) || - (gather_expr != nullptr && !gather_expr->windowShape()[i]->isOneInt() && + (gather_expr != nullptr && gather_expr->windowShape()[i] != 1 && !consumer_id->isBroadcast())) { return true; } @@ -839,13 +718,6 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const { return false; } -bool HaloInfo::needsShiftPredicate(kir::Expr* expr) const { - const auto out_tv = expr->outputs()[0]->as(); - auto fuser_expr = out_tv->fuserTv()->definition(); - TORCH_INTERNAL_ASSERT(fuser_expr != nullptr); - return needsShiftPredicate(fuser_expr); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index 378709ca443..c0fea8c1ead 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -16,16 +16,14 @@ namespace cuda { //! Auxiliary class to represent information about halo of an axis class AxisHaloInfo { public: - AxisHaloInfo(); - //! Width of halo. //! //! pos is either 0 or 1. The width of halo at offset zero is set //! when pos is 0. - kir::Int* width(int pos) const; + int width(int pos) const; //! Sum of the widths of both widths - kir::Int* width() const; + int width() const; const auto& widths() const { return widths_; @@ -34,10 +32,10 @@ class AxisHaloInfo { //! Set the halo width of either side. //! pos is either 0 or 1. The width of halo at offset zero is set //! when pos is 0. - void setWidth(int pos, kir::Int* width); + void setWidth(int pos, int width); //! Extend the halo width to account for another axis. - void merge(int pos, kir::Int* other); + void merge(int pos, int other); //! Extend the halo width to account for another axis. void merge(const AxisHaloInfo& other); @@ -53,7 +51,7 @@ class AxisHaloInfo { //! widths_[0] is non-zero and designates the size of the //! halo. Similarly, non-zero widths_[1] means the axis has halo at //! the other end of the axis. - std::array widths_ = {nullptr, nullptr}; + std::array widths_ = {0, 0}; }; //! Helper class for lowering tensors with halo. Only valid at the @@ -77,7 +75,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! Returns true if id has the root halo information set by //! setRootAxisInfo. bool hasRootAxisInfo(IterDomain* id) const; - bool hasRootAxisInfo(kir::IterDomain* id) const; //! Returns the registed AxisHaloInfo of a root axis. //! @@ -85,9 +82,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! non-root axes. const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const; AxisHaloInfo& getRootAxisInfo(IterDomain* id); - //! KIR version - const AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id) const; - AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id); //! Query if an axis has a halo width. //! @@ -98,12 +92,11 @@ class TORCH_CUDA_CU_API HaloInfo { //! //! It's an error if queried for an axis with no halo width //! information. - kir::Int* getHaloWidth(IterDomain* id) const; + int getHaloWidth(IterDomain* id) const; //! Returns an extent if id is extended for halo. Nullptr is //! returned otherwise. - kir::Val* getExtent(IterDomain* id) const; - kir::Val* getExtent(kir::IterDomain* id) const; + Val* getExtent(IterDomain* id) const; //! Returns all child domains of a root domain that inherits the //! halo of the root domain. @@ -135,7 +128,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! interior and another for padding. Predicate insertion is done in //! the ShiftPredicateInserter class below. bool needsShiftPredicate(Expr* expr) const; - bool needsShiftPredicate(kir::Expr* expr) const; std::string toString() const; @@ -166,14 +158,14 @@ class TORCH_CUDA_CU_API HaloInfo { //! Validate shift usage void validate(TensorView* td) const; + void setHaloWidth(IterDomain* id, int halo_width); + private: //! Halo information of root axes std::unordered_map root_axis_map_; - //! KIR version - std::unordered_map kir_root_axis_map_; //! Halo-extended extents. No mapping for axes without halo extension - std::unordered_map kir_extent_map_; + std::unordered_map extent_map_; //! The halo width of an axis. //! @@ -209,7 +201,7 @@ class TORCH_CUDA_CU_API HaloInfo { //! inner axis is merged with another axis of extent M, we know that //! the extent of the resulting output axis is 5*M, but we don't //! create its mapping. - std::unordered_map halo_width_map_; + std::unordered_map halo_width_map_; //! Mappings from root domains to child domains that inherit halo std::unordered_map> @@ -224,9 +216,9 @@ class ShiftPredicateInserter { //! the usual predicated expression, so the insertion is also done //! here. static void insert( - kir::Expr* expr, + Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, bool within_unswitch); }; diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp index a7f8768883d..8721490feb7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -17,55 +16,49 @@ namespace cuda { namespace { -kir::Bool* getPredicatePerParallelType( +Bool* getPredicatePerParallelType( ParallelType pt, const ThreadPredicateMap::PredicateInfo& pred_info) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); // If pt is not used or is proven to be one, no need to predicate. if (pt_dim == nullptr || pt_dim->isOneInt()) { - return ir_builder.trueVal(); + return GpuLower::current()->kernel()->trueVal(); } - // When BID needs to be predicated, that means it's an output of a grid // reduction and only the last block index in that dimension has the right // value from the grid reduce. if (isParallelTypeBlockDim(pt) && pred_info.limited_types.get(pt)) { - return ir_builder - .eqExpr( - kir::NamedScalar::getParallelIndex(pt), - ir_builder.subExpr( - kir::NamedScalar::getParallelDim(pt), ir_builder.oneVal())) - ->as(); + return SimplifyingIrBuilder::eqExpr( + NamedScalar::getParallelIndex(pt), + SimplifyingIrBuilder::subExpr( + NamedScalar::getParallelDim(pt), + GpuLower::current()->kernel()->oneVal())) + ->as(); } // Otherwise, only thread of index 0 executes the computation - return ir_builder - .eqExpr(kir::NamedScalar::getParallelIndex(pt), ir_builder.zeroVal()) - ->as(); + return SimplifyingIrBuilder::eqExpr( + NamedScalar::getParallelIndex(pt), + GpuLower::current()->kernel()->zeroVal()) + ->as(); } } // namespace -kir::Bool* ThreadPredicateMap::getPredicateFromPredicateInfo( +Bool* ThreadPredicateMap::getPredicateFromPredicateInfo( const ThreadPredicateMap::PredicateInfo& pred_info) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - const auto pred_types = pred_info.limited_types | pred_info.redundant_types; if (pred_types.none()) { - return ir_builder.trueVal(); + return GpuLower::current()->kernel()->trueVal(); } - kir::Bool* pred = nullptr; - + Bool* pred = nullptr; for (const auto pt : pred_types) { const auto tp = getPredicatePerParallelType(pt, pred_info); - pred = ir_builder.andExpr(pred, tp)->as(); + pred = SimplifyingIrBuilder::andExpr(pred, tp)->as(); } - TORCH_INTERNAL_ASSERT(pred != nullptr); return pred; @@ -191,7 +184,9 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) { if (id->isReduction()) { id_reductions.set(id->getParallelType()); } - if (id->isBroadcast()) { + if (id->isBroadcast() && + GpuLower::current()->concretizedBroadcastDomains().isConcretized( + id)) { id_bcasts.set(id->getParallelType()); } } @@ -302,7 +297,7 @@ void ThreadPredicateMap::insert( thread_predicates_.insert({tv, pred_info}); } -kir::Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const { +Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const { TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv); auto pred_info = getPredicateInfo(tv); return getPredicateFromPredicateInfo(pred_info); @@ -326,7 +321,8 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains( const bool output_smem = tv->getMemoryType() == MemoryType::Shared; for (auto id : iter_domains) { - if (!id->isBroadcast()) { + if (!id->isBroadcast() || + !GpuLower::current()->concretizedBroadcastDomains().isConcretized(id)) { continue; } if (id->isBlockDim() || (!output_smem && id->isThreadDim())) { diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h index 256e0385aeb..0d7a2685b32 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h @@ -1,7 +1,7 @@ #pragma once -#include +#include #include #include @@ -69,7 +69,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap { ParallelTypeBitmap getPredicatedParallelTypes(const TensorView* tv) const; //! Returns a Bool predicate for a given TensorView. - kir::Bool* getPredicate(const TensorView* tv) const; + Bool* getPredicate(const TensorView* tv) const; //! Returns a ParallelTypeBitmap representing which domain needs //! blockBroadcast. @@ -81,7 +81,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap { void print() const; //! Generate a Bool value from PredicateInfo. - static kir::Bool* getPredicateFromPredicateInfo( + static Bool* getPredicateFromPredicateInfo( const ThreadPredicateMap::PredicateInfo& pred_info); private: diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp new file mode 100644 index 00000000000..ab62530591a --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp @@ -0,0 +1,119 @@ +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +void ConcretizedBroadcastDomains::build(Fusion* fusion) { + // Initialize the origin map with input broadcast domains + for (const auto fusion_input_tv : + ir_utils::filterByType(fusion->inputs())) { + for (auto root_id : fusion_input_tv->getRootDomain()) { + if (root_id->isBroadcast()) { + broadcast_origin_map_.emplace( + root_id, std::unordered_set({root_id})); + } + } + } + traverse(fusion); +} + +bool ConcretizedBroadcastDomains::isConcretized(IterDomain* id) const { + auto it = concretized_domains_.find(id); + return it != concretized_domains_.end(); +} + +void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) { + // Create a new entry for each of new broadcast domains + auto out = bop->out()->as(); + for (const auto i : c10::irange(out->getRootDomain().size())) { + if (bop->getBroadcastDimFlags().at(i)) { + auto new_bcast_id = out->getRootDomain().at(i); + broadcast_origin_map_.emplace( + new_bcast_id, std::unordered_set({new_bcast_id})); + } + } +} + +void ConcretizedBroadcastDomains::handle(Expr* expr) { + IterVisitor::handle(expr); + + // Propagate broadcast origin info from producers to consumers + for (auto producer : ir_utils::filterByType(expr->inputs())) { + std::unordered_set producer_broadcasts; + // This assumes there's no merged broadcast axes between root and rfactor + // domains which is not possible at the moment. If this assumption is ever + // invalidated we would need to manaually propagate root IDs to rfactor IDs. + for (auto producer_id : producer->getMaybeRFactorDomain()) { + if (producer_id->isBroadcast()) { + producer_broadcasts.insert(producer_id); + } + } + if (producer_broadcasts.empty()) { + continue; + } + + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + auto p2c_map = + PairwiseRootDomainMap(producer, consumer) + .mapProducerToConsumer( + producer->domain(), consumer->domain(), producer_broadcasts); + for (const auto& kv : p2c_map) { + auto p_id = kv.first; + auto c_id = kv.second; + const bool is_concretized = !c_id->isBroadcast(); + auto it = broadcast_origin_map_.find(p_id); + TORCH_INTERNAL_ASSERT( + it != broadcast_origin_map_.end(), + "Broadcast origin info not found for producer broadcast domain: ", + p_id->toString(), + " of ", + producer->toString()); + const auto& producer_origins = it->second; + if (is_concretized) { + // Keep track of all the origin domains as concretized + for (auto origin : producer_origins) { + // concretized_root_domains_.insert(origin); + markAsConcretized(origin); + } + } else { + // Not concretized yet. Propagate forward the origin info. + auto& consumer_origins = broadcast_origin_map_[c_id]; + for (auto origin : producer_origins) { + consumer_origins.insert(origin); + } + consumer_origins.insert(c_id); + } + } + } + } +} + +void ConcretizedBroadcastDomains::markAsConcretized(IterDomain* root_domain) { + std::deque child_domains({root_domain}); + while (!child_domains.empty()) { + auto child = child_domains.front(); + child_domains.pop_front(); + if (!concretized_domains_.emplace(child).second) { + continue; + } + const auto& child_uses = child->uses(); + for (auto child_use : child_uses) { + for (auto out_id : + ir_utils::filterByType(child_use->outputs())) { + child_domains.push_back(out_id); + } + } + } +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h new file mode 100644 index 00000000000..9dd50e8afc1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +//! Traverse and collect all concretized broadcast domains. +//! +//! The traversal first initializes the origin map with broadcast +//! domains in input tensors. Then, a new entry is added to the origin +//! map when a broadcast op is encountered during a forward traversal +//! of the given fusion. For non-broadcast ops, mappings are just +//! propagated forward using PairwiseRootDomainMap. +//! +//! When the mapped consumer domain is not broadcast, it means the +//! producer broadcast domain is concretized, and its origin broadcast +//! domains are marked as concretized. +class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor { + public: + void build(Fusion* fusion); + + bool isConcretized(IterDomain* id) const; + + private: + using IterVisitor::handle; + + void handle(BroadcastOp* bop) final; + + void handle(Expr* expr) final; + + void markAsConcretized(IterDomain* root_domain); + + private: + //! Maps each broadcast domain to its original broadcast + //! domains. Their can be multiple original domains due to, e.g., + //! binary ops with broadcast domains in both inputs. + std::unordered_map> + broadcast_origin_map_; + //! Set of all concretized original domains + std::unordered_set concretized_domains_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp index 33651785d43..a8905b4d404 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp @@ -74,7 +74,7 @@ bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) { } // namespace -void TrivialReductionInfo::build(Fusion* fusion, GpuLower* gpu_lower) { +void TrivialReductionInfo::build(Fusion* fusion) { auto used_vals = fusion->usedMathVals(); for (auto tv : ir_utils::filterByType(used_vals)) { @@ -99,20 +99,6 @@ void TrivialReductionInfo::build(Fusion* fusion, GpuLower* gpu_lower) { } } } - - buildKir(fusion, gpu_lower); -} - -void TrivialReductionInfo::buildKir(Fusion* fusion, GpuLower* gpu_lower) { - for (auto id : domains_) { - auto kir_trivial_id = gpu_lower->lowerValue(id)->as(); - kir_domains_.insert(kir_trivial_id); - } - - for (auto id : domains_derived_from_root_) { - auto kir_trivial_id = gpu_lower->lowerValue(id)->as(); - kir_domains_derived_from_root_.insert(kir_trivial_id); - } } bool TrivialReductionInfo::isDerived(IterDomain* id) const { @@ -124,15 +110,6 @@ bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const { domains_derived_from_root_.end(); } -bool TrivialReductionInfo::isDerived(kir::IterDomain* id) const { - return kir_domains_.find(id) != kir_domains_.end(); -} - -bool TrivialReductionInfo::isDerivedFromRoot(kir::IterDomain* id) const { - return kir_domains_derived_from_root_.find(id) != - kir_domains_derived_from_root_.end(); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h index c16439ed4f0..9ccbc2f7828 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -13,24 +13,15 @@ namespace jit { namespace fuser { namespace cuda { -class GpuLower; - //! Detect almost all IterDomains that are derived from trivial //! reductons. class TORCH_CUDA_CU_API TrivialReductionInfo { public: - void build(Fusion* fusion, GpuLower* gpu_lower); + void build(Fusion* fusion); bool isDerived(IterDomain* id) const; bool isDerivedFromRoot(IterDomain* id) const; - bool isDerived(kir::IterDomain* id) const; - bool isDerivedFromRoot(kir::IterDomain* id) const; - - private: - //! Convert the sets to KIR sets - void buildKir(Fusion* fusion, GpuLower* gpu_lower); - private: //! IterDomains that are derived only from trivial //! reductons. Included domains are not limited to reduction axes as @@ -48,9 +39,6 @@ class TORCH_CUDA_CU_API TrivialReductionInfo { //! trivial reductions. These domains do not need to manifest as //! for-loops. std::unordered_set domains_derived_from_root_; - - std::unordered_set kir_domains_; - std::unordered_set kir_domains_derived_from_root_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 08f91ba59bd..c4f926131a8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include #include @@ -22,8 +20,7 @@ namespace { // Provide a new for loop matching the one provided kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) { - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - const auto new_loop = ir_builder.create(for_loop); + const auto new_loop = IrBuilder::create(for_loop); for (auto expr : for_loop->body().exprs()) { if (auto nested_for_loop = dynamic_cast(expr)) { expr = cloneLoopNest(nested_for_loop); @@ -35,20 +32,20 @@ kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) { // Returns true if expr is an expression that initializes a reduction // buffer. -bool isReductionInitExpr(const kir::Expr* expr) { +bool isReductionInitExpr(const Expr* expr) { // False if its output isn't a TensorView - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { return false; } // False if it doesn't have any reduction axis - const auto out_tv = expr->outputs()[0]->as(); + const auto out_tv = expr->outputs()[0]->as(); if (!out_tv->domain()->hasReduction()) { return false; } // False if it has have TensorView inputs as initialization should // never use TensorViews const auto tv_filter_inp_view = - ir_utils::filterByType(expr->inputs()); + ir_utils::filterByType(expr->inputs()); if (tv_filter_inp_view.begin() != tv_filter_inp_view.end()) { return false; } @@ -57,28 +54,27 @@ bool isReductionInitExpr(const kir::Expr* expr) { } // namespace -void UnrollPass::handle(kir::Expr* expr) { - if (ir_utils::isTVOp(expr)) { +void UnrollPass::handle(Expr* expr) { + if (ir_utils::isTvOp(expr)) { // If tv op, predicate it - const auto out_tv = ir_utils::getTVOutput(expr); + const auto out_tv = ir_utils::getTvOutput(expr); const bool should_predicate = !for_loops_.empty() || - out_tv->memoryType() == MemoryType::Global || - out_tv->memoryType() == MemoryType::Shared; + out_tv->getMemoryType() == MemoryType::Global || + out_tv->getMemoryType() == MemoryType::Shared; if (!should_predicate) { return; } - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); const auto thread_pred = isReductionInitExpr(expr) - ? ir_builder.trueVal() - : GpuLower::current()->threadPredMap().getPredicate(out_tv->fuserTv()); + ? GpuLower::current()->kernel()->trueVal() + : GpuLower::current()->threadPredMap().getPredicate(out_tv); // When this expr is in an unswitched block, only attach the // thread predicate to the expr as thread predicates are not // grouped to the unswitch predicate. kir::Predicate* thread_pred_expr = nullptr; if (unswitched_loop_) { - thread_pred_expr = ir_builder.create(thread_pred); + thread_pred_expr = IrBuilder::create(thread_pred); } non_trivial_pred_found_ = true; @@ -95,7 +91,7 @@ void UnrollPass::handle(kir::Expr* expr) { if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) { const auto write_pred = unswitched_loop_ ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::ReductionWrite, expr, thread_pred); expr->setWritePredicate(write_pred); } @@ -105,7 +101,7 @@ void UnrollPass::handle(kir::Expr* expr) { if (ir_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { const auto pred = unswitched_loop_ ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::Inline, expr, thread_pred); expr->setPredicate(pred); return; @@ -116,28 +112,28 @@ void UnrollPass::handle(kir::Expr* expr) { if (!unswitched_loop_ && std::any_of( for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) { - return fl->iter_domain()->parallelType() == + return fl->iter_domain()->getParallelType() == ParallelType::Vectorize; })) { - pred = ir_builder.create(PredicateType::Vectorize); + pred = IrBuilder::create(PredicateType::Vectorize); } if (pred == nullptr) { pred = unswitched_loop_ ? thread_pred_expr - : ir_builder.create( + : IrBuilder::create( PredicateType::Inline, expr, thread_pred); } // If we need a predicate, put expr inside an if then else - kir::IfThenElse* inline_ite = ir_builder.create(pred); + kir::IfThenElse* inline_ite = IrBuilder::create(pred); if (for_loops_.empty()) { // Special handling for top level output expressions that still // need predicates. One motivating example is a reduction op that // reduces to a scalar (issue #491) - expr_replacement_map_.insert({expr, inline_ite}); + kir::ExprMutator::registerReplace(expr, inline_ite, nullptr); } else { - for_loops_.back()->body().insert_before(expr, inline_ite); - for_loops_.back()->body().erase(expr); + kir::ExprMutator::registerReplace( + expr, inline_ite, &for_loops_.back()->body()); } inline_ite->thenBody().push_back(expr); } else if (auto for_loop = dynamic_cast(expr)) { @@ -150,8 +146,8 @@ void UnrollPass::handle(kir::Expr* expr) { void UnrollPass::handle(kir::ForLoop* fl) { // Setup for loop scoping const bool is_unroll = - fl->iter_domain()->parallelType() == ParallelType::Unroll || - fl->iter_domain()->parallelType() == ParallelType::Unswitch; + fl->iter_domain()->getParallelType() == ParallelType::Unroll || + fl->iter_domain()->getParallelType() == ParallelType::Unswitch; // If we're not looking for an unroll loop, or didn't find one, process as // normal. @@ -172,10 +168,9 @@ void UnrollPass::handle(kir::ForLoop* fl) { return; } - kir::IrBuilder ir_builder(GpuLower::current()->kernel()); - auto unroll_pred = ir_builder.create(fl); + auto unroll_pred = IrBuilder::create(fl); - kir::IfThenElse* unroll_ite = ir_builder.create(unroll_pred); + kir::IfThenElse* unroll_ite = IrBuilder::create(unroll_pred); // Get the loop nest for the unrolled path kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl); @@ -199,12 +194,18 @@ void UnrollPass::handle(kir::ForLoop* fl) { handle(inlined_loop); look_for_unroll_ = true; if (!non_trivial_pred_found_) { - expr_replacement_map_.insert({fl, inlined_loop}); + kir::ExprMutator::registerReplace( + fl, + inlined_loop, + for_loops_.empty() ? nullptr : &for_loops_.back()->body()); } else { if (!canOmitElseClause(fl)) { unroll_ite->elseBody().push_back(inlined_loop); } - expr_replacement_map_.insert({fl, unroll_ite}); + kir::ExprMutator::registerReplace( + fl, + unroll_ite, + for_loops_.empty() ? nullptr : &for_loops_.back()->body()); } } @@ -221,14 +222,14 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { // If there's any expression that requires barrier // synchronization, the else part can't be omitted for (auto expr : loop->body().exprs()) { - if (expr->isA()) { + if (expr->isA()) { const ParallelTypeBitmap domains = pred_map.getParallelBroadcastDomains( - expr->outputs()[0]->as()->fuserTv()); + expr->outputs()[0]->as()); if (domains.any()) { return false; } - } else if (expr->isA() || expr->isA()) { - auto td = ir_utils::getTVOutput(expr)->domain(); + } else if (expr->isA() || expr->isA()) { + auto td = ir_utils::getTvOutput(expr)->domain(); if (td->hasBlockReduction() || td->hasGridReduction()) { return false; } @@ -238,14 +239,14 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { // unswitch predicate is sufficient. // When the loop stop is the same as the extent of its IterDomain, // the per-thread visit count is guaranteed to be one at most (see - // CudaKernelGenerator::visit(kir::ForLoop*) as well. Also, when a + // CudaKernelGenerator::handle(kir::ForLoop*) as well. Also, when a // loop is vectorized (not misaligned), the count must be one at // most. Even if not parallelized nor vectoirzed, it is also // sufficient if the loop stop is in fact one. bool visit_once = false; auto id = loop->iter_domain(); if ((id->isThread() && (loop->stop() == id->extent())) || - id->parallelType() == ParallelType::Vectorize) { + id->getParallelType() == ParallelType::Vectorize) { visit_once = true; } if (!visit_once) { @@ -273,30 +274,18 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { } // Generate the loop nest structure and place it in lowered_exprs -UnrollPass::UnrollPass(const std::vector& exprs) { +UnrollPass::UnrollPass(const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::computeMap"); - - // Run through loop nests and further lower the expressions - for (auto* expr : exprs) { - handle(expr); - } + kir::ExprMutator::traverseAndInsert(exprs); } -std::vector UnrollPass::runPass( +std::vector UnrollPass::runPass( Fusion* fusion, - const std::vector& exprs) { + const std::vector& exprs) { FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::runPass"); UnrollPass unroll_pass(exprs); - - std::vector mutated_exprs; - mutated_exprs.reserve(exprs.size()); - for (auto expr : exprs) { - mutated_exprs.push_back( - ir_utils::applyReplacements(unroll_pass.replacementMap(), expr)); - } - - return mutated_exprs; + return unroll_pass.exprs_; } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index bec4966dd94..14725c405b7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include #include #include #include @@ -51,33 +52,32 @@ namespace cuda { //! predicate still in the inner most loop, making sure that we cover edges and //! corners. //! -class TORCH_CUDA_CU_API UnrollPass { +class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { public: // Take the incoming exprs and run loop unrolling, returning the new IR - static std::vector runPass( + static std::vector runPass( Fusion* fusion, - const std::vector& exprs); + const std::vector& exprs); static bool canOmitElseClause(kir::ForLoop* fl); private: // Generate the for Expr replacement map - UnrollPass(const std::vector& exprs); + UnrollPass(const std::vector& exprs); - const std::unordered_map& replacementMap() const { + const std::unordered_map& replacementMap() const { return expr_replacement_map_; } - void handle(kir::ForLoop* fl); + using OptOutDispatch::handle; - void handle(kir::Expr* expr); + void handle(kir::ForLoop* fl) final; + + void handle(Expr* expr) final; private: // We will track which loops in the incoming IR will be replaced and by what - std::unordered_map expr_replacement_map_; - - // Keep all for loops conveniently to make unrolling easier - std::vector for_loops_; + std::unordered_map expr_replacement_map_; // keep track if we're within an unrolled loop bool look_for_unroll_ = true; diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 5d015c450d9..ba2f618efae 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include #include @@ -23,38 +21,14 @@ namespace cuda { namespace scope_utils { -std::vector getLoops(kir::Expr* scope) { - std::vector loops; - while (scope != nullptr) { - if (auto loop = dynamic_cast(scope)) { - loops.push_back(loop); - } - scope = scope->parentScope(); - } - std::reverse(loops.begin(), loops.end()); - return loops; -} - -void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr) { - if (auto ite = dynamic_cast(scope)) { - ite->thenBody().insert_before(ref, expr); - } else if (auto for_loop = dynamic_cast(scope)) { - for_loop->body().insert_before(ref, expr); - } else { - TORCH_INTERNAL_ASSERT(false, "Unexpected scope expression"); - } -} - //! Create an **empty** Forloop and copy the metadata. -kir::ForLoop* cloneForLoop(kir::IrBuilder& ir_builder, kir::ForLoop* for_loop) { - return ir_builder.create(for_loop); +kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop) { + return IrBuilder::create(for_loop); } //! Create an **empty** IfThenElse and copy the metadata. -kir::IfThenElse* cloneIfThenElse( - kir::IrBuilder& ir_builder, - kir::IfThenElse* ite) { - return ir_builder.create(ite->predicate()); +kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite) { + return IrBuilder::create(ite->predicate()); } } // namespace scope_utils @@ -103,17 +77,18 @@ std::vector iterDomainInputsOfOrderedAs( } bool isTV(const Val* val) { - return val->getValType().value() == ValType::TensorView; + return val->getValType().value() == ValType::TensorView || + val->getValType().value() == ValType::TensorIndex; } // Check if we're a TensorView op that we can generate code for. -bool isTVOp(const Expr* expr) { +bool isTvOp(const Expr* expr) { if (std::any_of( expr->outputs().begin(), expr->outputs().end(), [](Val* v) { return isTV(v); }) && - (expr->getExprType().value() == ExprType::BinaryOp || - expr->getExprType().value() == ExprType::UnaryOp || + (expr->getExprType().value() == ExprType::UnaryOp || + expr->getExprType().value() == ExprType::BinaryOp || expr->getExprType().value() == ExprType::TernaryOp || expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::WelfordOp || @@ -121,28 +96,26 @@ bool isTVOp(const Expr* expr) { expr->getExprType().value() == ExprType::TransposeOp || expr->getExprType().value() == ExprType::ShiftOp || expr->getExprType().value() == ExprType::GatherOp || - expr->getExprType().value() == ExprType::ViewOp)) { + expr->getExprType().value() == ExprType::ViewOp || + expr->getExprType().value() == ExprType::GridReduction || + expr->getExprType().value() == ExprType::GridBroadcast || + expr->getExprType().value() == ExprType::GridWelford)) { return true; } return false; } -bool isTVOp(const kir::Expr* expr) { - const auto& outputs = expr->outputs(); - return outputs.size() >= 1 && outputs[0]->isA(); -} - -kir::TensorView* getTv(kir::Val* val) { - if (auto tv = dynamic_cast(val)) { - return tv; - } else if (auto ti = dynamic_cast(val)) { - return ti->view(); +TensorView* getTv(Val* val) { + if (val->isA()) { + return val->as(); + } else if (val->isA()) { + return val->as()->view(); } return nullptr; } -std::vector getTvs(const std::vector& vals) { - std::vector tvs; +std::vector getTvs(const std::vector& vals) { + std::vector tvs; for (auto val : vals) { auto tv = ir_utils::getTv(val); if (tv) { @@ -152,32 +125,7 @@ std::vector getTvs(const std::vector& vals) { return tvs; } -kir::TensorView* asTv(kir::Val* val) { - auto tv = getTv(val); - TORCH_INTERNAL_ASSERT(tv != nullptr, "Neigher TensorView nor TensorIndex"); - return tv; -} - -std::vector asTvs(const std::vector vals) { - std::vector tvs; - for (auto val : vals) { - auto tv = ir_utils::asTv(val); - tvs.emplace_back(tv); - } - return tvs; -} - -// TODO: why do we assume there's a single TV output? -TensorView* getTVOutput(const Expr* expr) { - for (auto out : expr->outputs()) { - if (out->getValType().value() == ValType::TensorView) { - return out->as(); - } - } - return nullptr; -} - -kir::TensorView* getTVOutput(const kir::Expr* expr) { +TensorView* getTvOutput(const Expr* expr) { for (auto out : expr->outputs()) { if (auto tv = getTv(out)) { return tv; @@ -193,25 +141,20 @@ bool isScalarOp(const Expr* expr) { return true; } -Expr* asExpr(Statement* stmt) { - TORCH_INTERNAL_ASSERT(stmt->isExpr()); - return stmt->as(); -} - -TensorView* asTV(Val* val) { - TORCH_INTERNAL_ASSERT(isTV(val)); - return val->as(); -} - bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { - if (!isTVOp(expr)) { + if (!isTvOp(expr)) { + return false; + } + + if (!(expr->isA() || expr->isA() || + expr->isA() || expr->isA() || + expr->isA() || expr->isA())) { return false; } - auto tv = getTVOutput(expr); + auto tv = getTvOutput(expr); - if ((expr->isA() || expr->isA()) && - (tv->hasBlockReduction() || tv->hasGridReduction())) { + if (tv->hasBlockReduction() || tv->hasGridReduction()) { return true; } else if (expr->isA()) { const ParallelTypeBitmap pt_map = @@ -222,64 +165,22 @@ bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { return false; } -bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map) { - if (expr->isA() || expr->isA() || - expr->isA() || expr->isA() || - expr->isA() || expr->isA()) { - auto fuser_tv = getTVOutput(expr)->fuserTv(); - auto fuser_expr = fuser_tv->definition(); - TORCH_INTERNAL_ASSERT(fuser_expr != nullptr); - return hasBlockSync(fuser_expr, pred_map); - } - - return false; -} - -kir::Expr* applyReplacements( - const std::unordered_map& expr_replacement_map, - kir::Expr* expr) { - auto handle_scope = [&](kir::Scope& scope) { - for (const auto i : c10::irange(scope.size())) { - scope[i] = applyReplacements(expr_replacement_map, scope[i]); - } - }; - - const auto it = expr_replacement_map.find(expr); - if (it != expr_replacement_map.end()) { - return it->second; - } else { - if (auto for_loop = dynamic_cast(expr)) { - handle_scope(for_loop->body()); - } else if (auto ite = dynamic_cast(expr)) { - handle_scope(ite->thenBody()); - handle_scope(ite->elseBody()); - } - return expr; - } -} - -c10::optional getMaybeWarpReductionDim( - const kir::ReductionOp* node) { - auto kir_tv = ir_utils::getTVOutput(node); - if (!kir_tv) { +c10::optional getMaybeWarpReductionDim(const ReductionOp* node) { + auto tv_out = getTv(node->out()); + if (tv_out == nullptr) { return c10::nullopt; } - auto fuser_reduction = kir_tv->fuserTv()->definition()->as(); - return getMaybeWarpReductionDim(fuser_reduction); -} -c10::optional getMaybeWarpReductionDim(const ReductionOp* node) { - auto fuser_tv_out = node->out()->as(); - auto fuser_tv_in = node->in()->as(); + auto tv_in = getTv(node->in()); // only support reducing to registers for now. - if (fuser_tv_in->getMemoryType() != MemoryType::Local || - fuser_tv_out->getMemoryType() != MemoryType::Local) { + if (tv_in->getMemoryType() != MemoryType::Local || + tv_out->getMemoryType() != MemoryType::Local) { return c10::nullopt; } IterDomain* reduction_on_xdim = nullptr; - for (auto id : fuser_tv_out->domain()->domain()) { + for (auto id : tv_out->domain()->domain()) { // Currently warp reduction only allows // serial and block.x parallel reductions if (id->isReduction() && id->isParallelized()) { @@ -302,7 +203,7 @@ c10::optional getMaybeWarpReductionDim(const ReductionOp* node) { return c10::optional(reduction_on_xdim); } - if (reduction_on_xdim->extent()->isConstScalar()) { + if (reduction_on_xdim->extent()->isConst()) { auto extent_value = reduction_on_xdim->extent()->getInt().value(); if (extent_value % at::cuda::warp_size() == 0) { return c10::optional(reduction_on_xdim); @@ -329,22 +230,22 @@ bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis) { }); } -std::unordered_map getParallelDomains( - kir::Val* val) { - kir::TensorView* kir_tv = nullptr; - if (val->isA()) { - kir_tv = val->as(); +std::unordered_map getParallelDomains( + Val* val) { + TensorView* tv = nullptr; + if (val->isA()) { + tv = val->as(); } else if (val->isA()) { - kir_tv = val->as()->view(); + tv = val->as()->view(); } else { TORCH_INTERNAL_ASSERT( false, "Provided val is not TensorIndex or TensorView."); } - std::unordered_map parallel_domains; - for (auto d : kir_tv->domain()->domain()) { + std::unordered_map parallel_domains; + for (auto d : tv->domain()->domain()) { if (d->isThread()) { - parallel_domains.insert(std::make_pair(d->parallelType(), d)); + parallel_domains.insert(std::make_pair(d->getParallelType(), d)); } } return parallel_domains; @@ -354,29 +255,60 @@ std::unordered_map getParallelDomains( namespace loop_utils { -// TODO: Clean this up, Naoya added a mechanism we should be able to reuse. -std::pair getAllocPoint( +BasicAllocInfo getAllocInformation( const TensorView* tv, - const std::vector& loops, + const std::vector& for_loops, const std::unordered_map& id_map, bool use_id_map) { - const auto gpu_lower = GpuLower::current(); + BasicAllocInfo info; + auto gpu_lower = GpuLower::current(); + const auto& loop_map = gpu_lower->caLoopMap(); - // If in global memory, it can be all the way outside the loops. - if (tv->getMemoryType() == MemoryType::Global) { - return {nullptr, 0}; - } + bool outer_alloc_found = false; + + for (auto fl : for_loops) { + if (info.alloc_pos == tv->getComputeAtPosition()) { + break; + } - // Figure out where we want to place alloc/reduction initialization. We want - // outside an unroll loop, or inside our computeAt point. - kir::ForLoop* alloc_loop = nullptr; + if (tv->axis(info.alloc_pos)->isReduction()) { + const auto outputs = FusionGuard::getCurFusion()->getTerminatingOutputs(); + TORCH_INTERNAL_ASSERT( + std::find(outputs.begin(), outputs.end(), tv) != outputs.end(), + "Invalid computeAt of T", + tv->name(), + ". A reducation axis is detected outside computeAt point even though it is not an output tensor."); + break; + } + + auto fl_id = fl->iter_domain(); + + if (fl_id->getParallelType() == ParallelType::Unroll) { + break; + } + + // Shared memory must be allocated outside of unswitched + // domains. See issue #1133. + if (fl_id->getParallelType() == ParallelType::Unswitch && + tv->getMemoryType() == MemoryType::Shared) { + outer_alloc_found = true; + } + + // Assume global memory is allocated at outer most scope. + if (tv->getMemoryType() == MemoryType::Global) { + outer_alloc_found = true; + } - auto loops_it = loops.begin(); - // Look at each axis individually in out's domain - for (const auto tv_i : c10::irange((int64_t)tv->getComputeAtPosition())) { - // Grab the axis ID + // Allocation of a double buffered tensor is placed outside its + // double buffer axis. + if (tv->isDoubleBuffered() && + tv->axis(info.alloc_pos) == + gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) { + outer_alloc_found = true; + } + + auto local_id = tv->axis(info.alloc_pos); - auto local_id = tv->axis(tv_i); if (use_id_map) { auto id_it = id_map.find(local_id); if (id_it != id_map.end()) { @@ -384,52 +316,33 @@ std::pair getAllocPoint( } } - if (gpu_lower->trivialReductionInfo().isDerivedFromRoot(local_id)) { - continue; + if (loop_map.areMapped(local_id, fl_id)) { + info.alloc_pos++; } - auto lowered_local_id = - gpu_lower->lowerValue(local_id)->as(); - loops_it = std::find_if( - loops_it, loops.end(), [&lowered_local_id](const auto& loop) { - return GpuLower::current()->caLoopMap().areMapped( - lowered_local_id, loop->iter_domain()) || - loop->iter_domain()->parallelType() == ParallelType::Unroll; - }); + info.init_for_loop = fl; - TORCH_INTERNAL_ASSERT( - loops_it != loops.end(), - "Could not find all required axes for indexing when trying to index into ", - tv); - if ((*loops_it)->iter_domain()->parallelType() == ParallelType::Unroll) { - return {alloc_loop, tv_i}; + if (!outer_alloc_found) { + info.alloc_for_loop = fl; } - - alloc_loop = *loops_it; - ++loops_it; } - return {alloc_loop, (int64_t)tv->getComputeAtPosition()}; -} - -std::pair getAllocPoint( - const TensorView* tv, - const std::vector& loops) { - return getAllocPoint(tv, loops, {}, false); + return info; } } // namespace loop_utils namespace { -class ReplaceExprInput : public kir::MutableIrVisitor { +class ReplaceExprInput : public OptOutDispatch { public: - static kir::Expr* replace( - kir::Expr* expr, - const std::unordered_map& replacement_map) { + using OptOutDispatch::handle; + static Expr* replace( + Expr* expr, + const std::unordered_map& replacement_map) { ReplaceExprInput replacer(expr, replacement_map); TORCH_INTERNAL_ASSERT(expr != nullptr); - expr->accept(&replacer); + replacer.handle(expr); TORCH_INTERNAL_ASSERT(replacer.replaced_expr_ != nullptr); auto ret_expr = replacer.replaced_expr_; @@ -441,10 +354,10 @@ class ReplaceExprInput : public kir::MutableIrVisitor { return ret_expr; } - static std::vector replace( - const std::vector& scope, - const std::unordered_map& replacement_map) { - std::vector ret_expr; + static std::vector replace( + const std::vector& scope, + const std::unordered_map& replacement_map) { + std::vector ret_expr; ret_expr.reserve(scope.size()); for (auto expr : scope) { @@ -455,20 +368,20 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } private: + // TODO: Replace this with mutator, example of this is done in replace + // symbolic sizes ReplaceExprInput( - kir::Expr* expr, - const std::unordered_map& replacement_map) - : gpu_lower_(GpuLower::current()), - ir_builder_(gpu_lower_->kernel()), - replacement_map_(replacement_map) { + Expr* expr, + const std::unordered_map& replacement_map) + : replacement_map_(replacement_map) { replaced_expr_ = expr; } - c10::optional> - getMaybeInputReplacementMap(kir::Expr* expr) { + c10::optional> getMaybeInputReplacementMap( + Expr* expr) { bool need_replacement = false; - std::unordered_map replaced_val; + std::unordered_map replaced_val; for (auto in : expr->inputs()) { auto replace_it = replacement_map_.find(in); if (replace_it != replacement_map_.end()) { @@ -479,16 +392,15 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } } if (need_replacement) { - return c10::optional>( - replaced_val); + return c10::optional>(replaced_val); } else { return c10::nullopt; } } // IR visitor interface - void visit(kir::ForLoop* for_loop) final { - auto new_for_loop = ir_builder_.create(for_loop); + void handle(kir::ForLoop* for_loop) final { + auto new_for_loop = IrBuilder::create(for_loop); auto replaced_loop_body = replace(for_loop->body().exprs(), replacement_map_); @@ -499,8 +411,8 @@ class ReplaceExprInput : public kir::MutableIrVisitor { replaced_expr_ = new_for_loop; } - void visit(kir::IfThenElse* ite) final { - auto new_ite = ir_builder_.create(ite->predicate()); + void handle(kir::IfThenElse* ite) final { + auto new_ite = IrBuilder::create(ite->predicate()); auto replaced_then_body = replace(ite->thenBody().exprs(), replacement_map_); for (auto new_expr : replaced_then_body) { @@ -516,31 +428,31 @@ class ReplaceExprInput : public kir::MutableIrVisitor { replaced_expr_ = new_ite; } - void visit(kir::UnaryOp* node) final { + void handle(UnaryOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getUnaryOpType(), node->out(), replaced_inputs.value().at(node->in())); } } - void visit(kir::BinaryOp* node) final { + void handle(BinaryOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getBinaryOpType(), node->out(), replaced_inputs.value().at(node->lhs()), replaced_inputs.value().at(node->rhs())); } } - void visit(kir::TernaryOp* node) final { + void handle(TernaryOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getTernaryOpType(), node->out(), replaced_inputs.value().at(node->in1()), replaced_inputs.value().at(node->in2()), @@ -548,29 +460,31 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } } - void visit(kir::ReductionOp* node) final { + void handle(ReductionOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->operation(), + replaced_expr_ = IrBuilder::create( + node->getReductionOpType(), node->init(), node->out(), replaced_inputs.value().at(node->in())); } } - void visit(kir::BroadcastOp* node) final { + void handle(BroadcastOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( - node->out(), replaced_inputs.value().at(node->in())); + replaced_expr_ = IrBuilder::create( + node->out(), + replaced_inputs.value().at(node->in()), + node->getBroadcastDimFlags()); } } - void visit(kir::WelfordOp* node) final { + void handle(WelfordOp* node) final { auto replaced_inputs = getMaybeInputReplacementMap(node); if (replaced_inputs.has_value()) { - replaced_expr_ = ir_builder_.create( + replaced_expr_ = IrBuilder::create( node->outAvg(), node->outVar(), node->outN(), @@ -584,17 +498,15 @@ class ReplaceExprInput : public kir::MutableIrVisitor { } private: - GpuLower* gpu_lower_; - kir::IrBuilder ir_builder_; - kir::Expr* replaced_expr_ = nullptr; - const std::unordered_map& replacement_map_; + Expr* replaced_expr_ = nullptr; + const std::unordered_map& replacement_map_; }; } // namespace -std::vector replaceInputsInExpr( - const std::vector& exprs, - const std::unordered_map& replacement_map) { +std::vector replaceInputsInExpr( + const std::vector& exprs, + const std::unordered_map& replacement_map) { return ReplaceExprInput::replace(exprs, replacement_map); } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 1c8a0df5cd7..4ed6c25e731 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -1,7 +1,7 @@ #pragma once -#include +#include #include #include @@ -19,27 +19,15 @@ namespace cuda { class ThreadPredicateMap; -using IterDomainMap = std::unordered_map; +using IterDomainMap = std::unordered_map; namespace scope_utils { -//! Returns the list of nesting loops starting at `scope` -// Primarily used in indexing, maybe could be moved there -std::vector getLoops(kir::Expr* scope); - -//! Insert expr in scope before ref -//! -//! \warning for kir::IfThenElse we implicitly insert in the "then" branch! -//! -void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr); - //! Create an **empty** Forloop and copy the metadata. -kir::ForLoop* cloneForLoop(kir::IrBuilder& ir_builder, kir::ForLoop* for_loop); +kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop); //! Create an **empty** IfThenElse and copy the metadata. -kir::IfThenElse* cloneIfThenElse( - kir::IrBuilder& ir_builder, - kir::IfThenElse* ite); +kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite); } // namespace scope_utils @@ -74,107 +62,80 @@ std::vector iterDomainInputsOfOrderedAs( const std::vector& of, const std::vector& order); +// Returns if Val is a TensorView or TensorIndex bool isTV(const Val* const); -TORCH_CUDA_CU_API bool isTVOp(const Expr*); - -bool isTVOp(const kir::Expr* expr); - -TensorView* getTVOutput(const Expr*); -kir::TensorView* getTVOutput(const kir::Expr*); - -bool isScalarOp(const Expr*); - -// TODO(kir): remove -Expr* asExpr(Statement*); +// Returns is Expr is a TensorView or TensorIndex Expr. +TORCH_CUDA_CU_API bool isTvOp(const Expr*); -// TODO(kir): Remove in favor of ->as() -TensorView* asTV(Val*); - -//! Get kir::TensorView potentially via kir::TensorIndex. Returns nullptr if -//! cast fails. -kir::TensorView* getTv(kir::Val*); - -//! Get only kir::TensorView potentially via kir::TensorIndex. -std::vector getTvs(const std::vector& vals); - -//! Get kir::TensorView potentially via kir::TensorIndex. Error if cast fails. -kir::TensorView* asTv(kir::Val*); - -//! Get kir::TensorView potentially via kir::TensorIndex. Error if cast fails. -std::vector asTvs(const std::vector& vals); +// Returns the first output of Expr that is a TensorView +TensorView* getTvOutput(const Expr*); bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map); -bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map); - -// expr_replacement_map maps an expression to its replacement. -// -// The applyReplacement function serves two purposes. -// -// 1. If expr is found in expr_replacement_map, return the value for expr key. -// Otherwise, return the original expression. -// -// 2. If a replacement is not found and the expression is a ForLoop or an -// IfThenElse, it modifies the expressions in its scope by running the -// handle_scope function -// -// The handle_scope function iterates over the expressions in the scope. -// For each expression, it updates the expression the value returned by -// applyReplacement. -kir::Expr* applyReplacements( - const std::unordered_map& expr_replacement_map, - kir::Expr* expr); //! Returns the Fuser iterdomain that maps to the thread dimension grouped //! to warps. Returns nullopt if the reduction is not to be lowered to //! a warp reduction. -c10::optional getMaybeWarpReductionDim( - const kir::ReductionOp* node); - c10::optional getMaybeWarpReductionDim(const ReductionOp* node); +bool isScalarOp(const Expr*); + +//! Get TensorView potentially via kir::TensorIndex. Returns nullptr if +//! cast fails. +TensorView* getTv(Val*); + +//! Get only TensorView potentially via kir::TensorIndex. +std::vector getTvs(const std::vector& vals); + //! Return true if axis is derived from a root axis that is an input //! to a CA leaf axis. bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis); -std::unordered_map getParallelDomains( - kir::Val* val); +std::unordered_map getParallelDomains( + Val* val); } // namespace ir_utils namespace loop_utils { -// I wanted to make the tv's in these util functions constant, but that started -// a long const-ness project going into TensorView (making functions const -// there) then into lower_loops where we sort exprs. -// TODO: We should fix this when we have some time. - -// Figure out which loop the allocation needs to be in. Returns nullptr if -// outside the first loop in loops. Also find out which index in tv the -// first dimension that needs to be allocated is. Meaning we need to allocate -// that local axis and above. -// TODO: Only remaining use of this is in index compute, remove use from there, -// or refactor and use in lower_allocation -std::pair getAllocPoint( - const TensorView* tv, - const std::vector& loops, - const std::unordered_map& id_map, - bool use_id_map); +struct BasicAllocInfo { + // The for loop that the initialization of this allocation must be + // placed in, nullptr if not within a loop + kir::ForLoop* init_for_loop = nullptr; + + // Keep track of the actual allocation loop. This can be different + // from init_for_loop only with unswitched shared memory allocations, + // which are moved outer loops to avoid duplicated allocations. This means + // that the alloc position may be outside what's expected. Most applications + // outside lower_allocation is likely looking for init_for_loop which is + // more directly related to how large an allocation is and how it's used. + // (see issue #1133). + kir::ForLoop* alloc_for_loop = nullptr; + + // The allocation position relative to buffer IDs, it could be outside the + // compute at position if it's shared memory with a compute at inside an + // unswitch + size_t alloc_pos = 0; +}; -std::pair getAllocPoint( +// Fill the above allocation struct based on provided information. id_map is +// used if we're looking at a producer tensor but loops on a consumer tensor. +BasicAllocInfo getAllocInformation( const TensorView* tv, - const std::vector& loops); + const std::vector& loops, + const std::unordered_map& id_map = {}, + bool use_id_map = false); } // namespace loop_utils // Replace value pass on Kernel IR. -// Replace each use of any kir::Val* that apears in the given `replacement_map` +// Replace each use of any Val* that apears in the given `replacement_map` // Keeps the predicate carried by each expr // // Warning: Blindly replaces all use based on pointer // Warning: May invalidate indexing if replacing uses of allocated values -std::vector replaceInputsInExpr( - const std::vector& exprs, - const std::unordered_map& replacement_map); +std::vector replaceInputsInExpr( + const std::vector& exprs, + const std::unordered_map& replacement_map); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 0579e44dcd6..25ba76ee71b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -319,7 +318,7 @@ class VectorizeValidator : public OptInDispatch { vector_size, " however, vector sizes only upto and including 16 bytes are supported."); - auto replay_exprs = ExprSort::getExprs(fusion, {v_id}); + auto replay_exprs = StmtSort::getExprs(fusion, {v_id}, false); VectorizeValidator validator(v_id); @@ -463,6 +462,14 @@ void validateParallelizationOfTensor(TensorView* tv) { continue; } + // It doesn't matter if this axis is a non-concretized broadcast + // TODO: merging broadcast and non-broadcast + if (axis->isBroadcast() && + !GpuLower::current()->concretizedBroadcastDomains().isConcretized( + axis)) { + continue; + } + TORCH_INTERNAL_ASSERT( !pt_map.get(ptype), "Multiple use of ", @@ -489,7 +496,7 @@ void validateParallelizationOfTensor(TensorView* tv) { ". The tensor is parallelized with ", predicated_parallel_types.toString(), ", but it's invalid to use the types as the tensor is also predicated with them.", - ", thread prd: ", + ", thread pred: ", thread_pred.limited_types.toString()); } @@ -503,10 +510,10 @@ void validateParallelize(Fusion* fusion) { const auto& loop_map = GpuLower::current()->caLoopMap(); const auto& pred_map = GpuLower::current()->threadPredMap(); - auto exprs = ExprSort::getExprs(fusion); + auto exprs = StmtSort::getExprs(fusion); for (auto expr : exprs) { - if (!ir_utils::isTVOp(expr)) { + if (!ir_utils::isTvOp(expr)) { continue; } // Validate parallelization of each consumer by itself @@ -630,7 +637,7 @@ namespace { // each tensor that needs to be computed. std::unordered_map> getLiveRangeOffsets( Fusion* fusion) { - auto exprs = ExprSort::getExprs(fusion); + auto exprs = StmtSort::getExprs(fusion); std::unordered_map> map; @@ -760,7 +767,9 @@ void validatePartialSplit(Fusion* fusion) { auto range_info = getLiveRangeOffsets(fusion); for (auto tv : ir_utils::allTvs(fusion)) { - auto exprs = ir_utils::historyOf(tv); + auto exprs = StmtSort::getExprs( + tv->fusion(), + {tv->domain()->domain().begin(), tv->domain()->domain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // When the start and stop offsets are not zero, make sure the // range defined by the split includes the required range to diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index 89de85026ee..115df13c322 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include diff --git a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp index eaddf7faea3..630d3128e78 100644 --- a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include #include @@ -18,20 +18,19 @@ namespace { //! and their corresponding allocations class EliminateDeadBroadcastAndAllocate { public: - static std::vector run(const std::vector& exprs) { + static std::vector run(const std::vector& exprs) { EliminateDeadBroadcastAndAllocate dce(exprs); return dce.result_exprs_; } private: - EliminateDeadBroadcastAndAllocate(const std::vector& exprs) - : ir_builder_(GpuLower::current()->kernel()) { + EliminateDeadBroadcastAndAllocate(const std::vector& exprs) { findLiveTvs(exprs); findDeadTvs(); eliminateDeadCode(exprs); } - void findLiveTvs(const std::vector& exprs) { + void findLiveTvs(const std::vector& exprs) { for (auto expr : exprs) { if (auto for_loop = dynamic_cast(expr)) { findLiveTvs(for_loop->body().exprs()); @@ -44,11 +43,10 @@ class EliminateDeadBroadcastAndAllocate { if (auto allocate = dynamic_cast(expr)) { if (allocate->memoryType() == MemoryType::Local) { - if (auto kir_tv = - dynamic_cast(allocate->buffer())) { + if (auto tv = dynamic_cast(allocate->buffer())) { // We know only tvs that we'd want to consider are broadcast outputs - if (kir_tv->fuserTv()->definition()->isA()) { - candidate_tv_set_.insert(kir_tv); + if (tv->definition()->isA()) { + candidate_tv_set_.insert(tv); } } } @@ -72,18 +70,18 @@ class EliminateDeadBroadcastAndAllocate { } } - void eliminateDeadCode(const std::vector& exprs) { + void eliminateDeadCode(const std::vector& exprs) { result_exprs_ = eliminateDeadCodeInScope(exprs); } - bool shouldEliminate(kir::Expr* expr) { + bool shouldEliminate(Expr* expr) { if (auto allocate = dynamic_cast(expr)) { - if (auto buffer_tv = dynamic_cast(allocate->buffer())) { + if (auto buffer_tv = dynamic_cast(allocate->buffer())) { if (dead_tvs_.count(buffer_tv)) { return true; } } - } else if (auto broadcast = dynamic_cast(expr)) { + } else if (auto broadcast = dynamic_cast(expr)) { if (auto out_ti = dynamic_cast(broadcast->out())) { if (dead_tvs_.count(out_ti->view())) { return true; @@ -95,9 +93,8 @@ class EliminateDeadBroadcastAndAllocate { //! Returns a new vector of exprs with dead exprs //! eliminated. - std::vector eliminateDeadCodeInScope( - const std::vector& exprs) { - std::vector result_exprs; + std::vector eliminateDeadCodeInScope(const std::vector& exprs) { + std::vector result_exprs; for (auto expr : exprs) { auto result_expr = expr; @@ -128,7 +125,7 @@ class EliminateDeadBroadcastAndAllocate { // TODO: we will need a kernel_ir cloner to make this // kind of logic re-usable. - auto new_loop = scope_utils::cloneForLoop(ir_builder_, for_loop); + auto new_loop = scope_utils::cloneForLoop(for_loop); for (auto expr : new_loop_body) { new_loop->body().push_back(expr); @@ -143,7 +140,7 @@ class EliminateDeadBroadcastAndAllocate { return nullptr; } - auto new_ite = scope_utils::cloneIfThenElse(ir_builder_, ite); + auto new_ite = scope_utils::cloneIfThenElse(ite); for (auto expr : new_then_body) { new_ite->thenBody().push_back(expr); @@ -155,12 +152,11 @@ class EliminateDeadBroadcastAndAllocate { } private: - std::unordered_set live_tvs_; - std::unordered_set dead_tvs_; - std::unordered_set candidate_tv_set_; + std::unordered_set live_tvs_; + std::unordered_set dead_tvs_; + std::unordered_set candidate_tv_set_; - std::vector result_exprs_; - kir::IrBuilder ir_builder_; + std::vector result_exprs_; }; //! A pass to eliminate redundant parallel broadcasts that are consumers @@ -189,9 +185,9 @@ class EliminateDeadBroadcastAndAllocate { //! //! 3. EliminateDeadBroadcastAndAllocate removes the broadcast ops //! and corresponding allocations if they're un-used after step 2. -class FuseBroadcastWithWarpReduce { +class FuseBroadcastWithWarpReduce : private kir::IrVisitor { public: - static std::vector fuse(const std::vector& exprs) { + static std::vector fuse(const std::vector& exprs) { FuseBroadcastWithWarpReduce fuse_broadcast_map(exprs); const auto replaced_inputs = replaceInputsInExpr(exprs, fuse_broadcast_map.val_replacement_map_); @@ -199,70 +195,50 @@ class FuseBroadcastWithWarpReduce { } private: - FuseBroadcastWithWarpReduce(const std::vector& exprs) { + FuseBroadcastWithWarpReduce(const std::vector& exprs) { // open stack space for global scope - // The scope stack for kir_tv_to_allocate wouldn't be needed + // The scope stack for tv_to_allocate wouldn't be needed // if the allocations are guaranteed to be once and unique, // which can currently be assumed but this pass tries not // to rely on this assumption. - running_kir_tv_to_allocate_map_.emplace_back( - std::make_unique< - std::unordered_map>()); + running_tv_to_allocate_map_.emplace_back( + std::make_unique>()); running_visible_allocation_stack_.emplace_back( std::make_unique>()); - - for (auto expr : exprs) { - handle(expr); - } + kir::IrVisitor::handle(exprs); } - void handle(kir::Expr* expr) { - if (auto for_loop = dynamic_cast(expr)) { - handle(for_loop); - return; - } else if (auto ite = dynamic_cast(expr)) { - handle(ite); - return; - } - - // Process expr inputs if needs replacement - for (auto inp : expr->inputs()) { - if (auto input_ti = dynamic_cast(inp)) { - auto replace = findMaybeReplacedTensorIndex(input_ti); - if (replace.has_value()) { - val_replacement_map_[input_ti] = replace.value(); + void handle(Expr* expr) final { + if (ir_utils::isTvOp(expr)) { + // Process expr inputs if needs replacement + for (auto inp : expr->inputs()) { + if (auto input_ti = dynamic_cast(inp)) { + auto replace = findMaybeReplacedTensorIndex(input_ti); + if (replace.has_value()) { + val_replacement_map_[input_ti] = replace.value(); + } } } } - - // Handle reduction definitions - if (auto reduction = dynamic_cast(expr)) { - handle(reduction); - } else if (auto broadcast = dynamic_cast(expr)) { - handle(broadcast); - } else if (auto allocate = dynamic_cast(expr)) { - handle(allocate); - } } - bool openLoopNestLevel(kir::IterDomain* id) { - if (id->isThread() || id->parallelType() == ParallelType::Unswitch) { + bool openLoopNestLevel(IterDomain* id) { + if (id->isThread() || id->getParallelType() == ParallelType::Unswitch) { return false; } - if (id->parallelType() == ParallelType::Serial || - id->parallelType() == ParallelType::Unroll) { + if (id->getParallelType() == ParallelType::Serial || + id->getParallelType() == ParallelType::Unroll) { return !id->isBroadcast(); } return true; } - void handle(kir::ForLoop* for_loop) { + void handle(kir::ForLoop* for_loop) final { // Keep track of visible reduction outputs bool open_nest_level = openLoopNestLevel(for_loop->iter_domain()); if (open_nest_level) { - running_kir_tv_to_allocate_map_.emplace_back( - std::make_unique< - std::unordered_map>()); + running_tv_to_allocate_map_.emplace_back( + std::make_unique>()); running_visible_allocation_stack_.emplace_back( std::make_unique>()); } @@ -270,12 +246,12 @@ class FuseBroadcastWithWarpReduce { handle(expr); } if (open_nest_level) { - running_kir_tv_to_allocate_map_.pop_back(); + running_tv_to_allocate_map_.pop_back(); running_visible_allocation_stack_.pop_back(); } } - void handle(kir::IfThenElse* ite) { + void handle(kir::IfThenElse* ite) final { running_visible_allocation_stack_.emplace_back( std::make_unique>()); for (auto expr : ite->thenBody().exprs()) { @@ -292,15 +268,14 @@ class FuseBroadcastWithWarpReduce { //! Place this allocate on the list of currently visible allocations, //! organized by loop nest level. - void handle(kir::Allocate* allocate) { + void handle(kir::Allocate* allocate) final { if (allocate->memoryType() != MemoryType::Local) { return; } - if (auto kir_tv = dynamic_cast(allocate->buffer())) { - auto fuser_tv = kir_tv->fuserTv(); - if (fuser_tv->definition()) { - if (fuser_tv->definition()->isA() || - fuser_tv->definition()->isA()) { + if (auto tv = dynamic_cast(allocate->buffer())) { + if (tv->definition()) { + if (tv->definition()->isA() || + tv->definition()->isA()) { running_visible_allocation_stack_.back()->push_back(allocate); } } @@ -311,18 +286,18 @@ class FuseBroadcastWithWarpReduce { //! returns the replaced TensorIndex if so. c10::optional findMaybeReplacedTensorIndex( kir::TensorIndex* tensor_index) { - auto kir_tv = tensor_index->view(); - auto tensor_index_it = running_tv_replacement_map_.find(kir_tv); + auto tv = tensor_index->view(); + auto tensor_index_it = running_tv_replacement_map_.find(tv); if (tensor_index_it != running_tv_replacement_map_.end()) { return tensor_index_it->second; } return c10::nullopt; } - //! Iteratve backwards on the currently visible loop scopes + //! Iterate backwards on the currently visible loop scopes //! and find the first allocation corresponding to the //! given tv. - kir::Allocate* getActiveAllocateFor(kir::TensorView* tv) { + kir::Allocate* getActiveAllocateFor(TensorView* tv) { for (auto frame_it = running_visible_allocation_stack_.rbegin(); frame_it != running_visible_allocation_stack_.rend(); frame_it++) { @@ -340,19 +315,10 @@ class FuseBroadcastWithWarpReduce { return nullptr; } - Expr* getFuserTVExpr(kir::Expr* expr) { - auto out = expr->outputs()[0]; - auto out_ti = dynamic_cast(out); - if (!out_ti) { - return nullptr; - } - return out_ti->view()->fuserTv()->definition(); - } - - bool isOpInputRegisterTV(kir::Expr* expr) { + bool isOpInputRegisterTV(Expr* expr) { for (auto inp : expr->inputs()) { if (auto inp_ti = dynamic_cast(inp)) { - if (inp_ti->view()->memoryType() != MemoryType::Local) { + if (inp_ti->view()->getMemoryType() != MemoryType::Local) { return false; } } @@ -361,10 +327,10 @@ class FuseBroadcastWithWarpReduce { return true; } - bool isOpOutputRegisterTV(kir::Expr* expr) { + bool isOpOutputRegisterTV(Expr* expr) { for (auto out : expr->outputs()) { if (auto out_ti = dynamic_cast(out)) { - if (out_ti->view()->memoryType() != MemoryType::Local) { + if (out_ti->view()->getMemoryType() != MemoryType::Local) { return false; } } @@ -374,8 +340,8 @@ class FuseBroadcastWithWarpReduce { } //! Updates map of serially visible reduction tvs, see comment on - //! running_kir_tv_to_allocate_map_. - void handle(kir::ReductionOp* reduction) { + //! running_tv_to_allocate_map_. + void handle(ReductionOp* reduction) final { if (!isOpOutputRegisterTV(reduction)) { return; } @@ -386,11 +352,11 @@ class FuseBroadcastWithWarpReduce { // keep track of which reduction buffer this expr writes into auto reduction_allocate = getActiveAllocateFor(reduction_ti_out->view()); - running_kir_tv_to_allocate_map_.back()->operator[]( - reduction_ti_out->view()) = reduction_allocate; + running_tv_to_allocate_map_.back()->operator[](reduction_ti_out->view()) = + reduction_allocate; } - void handle(kir::BroadcastOp* broadcast) { + void handle(BroadcastOp* broadcast) final { if (!isOpInputRegisterTV(broadcast) || !isOpOutputRegisterTV(broadcast)) { return; } @@ -400,9 +366,9 @@ class FuseBroadcastWithWarpReduce { //! Detects if this broadcast can be fused with the producer reduction. //! adds the output of broadcast to replacement map if all above mentioned //! conditions check. - void tryAddOutputToReplaceMap(kir::BroadcastOp* broadcast) { + void tryAddOutputToReplaceMap(BroadcastOp* broadcast) { if (auto in_ti = dynamic_cast(broadcast->in())) { - if (!in_ti->view()->fuserTv()->definition()->isA()) { + if (!in_ti->view()->definition()->isA()) { return; } auto out_ti = broadcast->out()->as(); @@ -410,15 +376,14 @@ class FuseBroadcastWithWarpReduce { // check reduction-broadcast mapping: if (!canFuseBroadcastWithWarpReduction( - out_tv->fuserTv()->definition()->as())) { + out_tv->definition()->as())) { return; } // check buffers are size-1 auto reduction_allocate_it = - running_kir_tv_to_allocate_map_.back()->find(in_ti->view()); - if (reduction_allocate_it == - running_kir_tv_to_allocate_map_.back()->end()) { + running_tv_to_allocate_map_.back()->find(in_ti->view()); + if (reduction_allocate_it == running_tv_to_allocate_map_.back()->end()) { // The producer reduction is not in the serially visible scope, // as defined in openLoopNestLevel. There still could be some // cases that we could fuse but disabled for simplicity. @@ -444,7 +409,7 @@ class FuseBroadcastWithWarpReduce { return; } - // Write the kir_tv in to the replacement map + // Write the tv in to the replacement map // so the future uses of this tv will put // the tensorIndex's in the actual replacement map. running_tv_replacement_map_[out_tv] = in_ti; @@ -515,7 +480,7 @@ class FuseBroadcastWithWarpReduce { //! could need some extension for more precise scope based analysis in the //! future especially if we have more complex IfThenElse blocks than //! predicates and unroll. - std::unordered_map + std::unordered_map running_tv_replacement_map_; //! Keeps track of the allocated buffers that the exprs will write/read @@ -531,21 +496,20 @@ class FuseBroadcastWithWarpReduce { //! visibility on the generated kernel. The model of IfThenElse assumes the //! only ITE's we have are predicates and unrolls, which might need to be //! more precise. - std::vector< - std::unique_ptr>> - running_kir_tv_to_allocate_map_; + std::vector>> + running_tv_to_allocate_map_; //! This map is the final output of this pass and a val replacement map will //! be run using //! it. All keys and values are TensorIndex's, and before this pass each //! TensorIndex is uniquely generated by lower_index pass for each access of - //! a kir_tv. - std::unordered_map val_replacement_map_; + //! a tv. + std::unordered_map val_replacement_map_; }; } // namespace -std::vector fuseWarpReduce(const std::vector exprs) { +std::vector fuseWarpReduce(const std::vector exprs) { return FuseBroadcastWithWarpReduce::fuse(exprs); } diff --git a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h index 785c0b59122..7480809c7dc 100644 --- a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h +++ b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.h @@ -13,7 +13,7 @@ struct WarpPaddedParallelInfo { bool has_warp_reduction = false; }; -std::vector fuseWarpReduce(const std::vector exprs); +std::vector fuseWarpReduce(const std::vector exprs); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp index ee1bea81535..0f5967c004d 100644 --- a/torch/csrc/jit/codegen/cuda/manager.cpp +++ b/torch/csrc/jit/codegen/cuda/manager.cpp @@ -141,6 +141,25 @@ class CudaFusionManager { int32_t next_unique_id_ = 0; }; +// Mark string attribute in alias-copy nodes to enable its implementation +// in the fallback path. +void enableAliasCopyNodes(const std::shared_ptr& graph, Block* block) { + static std::unordered_set alias_copy_op( + {prim::view_copy, + prim::reshape_copy, + prim::squeeze_copy, + prim::unsqueeze_copy}); + + for (Node* n : block->nodes()) { + for (Block* b : n->blocks()) { + enableAliasCopyNodes(graph, b); + } + if (alias_copy_op.find(n->kind()) != alias_copy_op.end()) { + n->s_(attr::name, "CudaFusionGroup"); + } + } +} + } // namespace void compileCudaFusionGroup(Node* fusion_node) { @@ -194,6 +213,7 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) { // copying graph here since we are eliminating shape information; auto copied_graph = fusion_node->g(attr::Subgraph)->copy(); EraseShapeInformation(copied_graph); + enableAliasCopyNodes(copied_graph, copied_graph->block()); InterpreterState{Code(copied_graph, "fallback_cuda_fuser")}.run(stack); }; diff --git a/torch/csrc/jit/codegen/cuda/manager.h b/torch/csrc/jit/codegen/cuda/manager.h index 39c97478eff..4b725cd80bc 100644 --- a/torch/csrc/jit/codegen/cuda/manager.h +++ b/torch/csrc/jit/codegen/cuda/manager.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include /* diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 8d13f1e299e..c24e444eb56 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -10,143 +11,177 @@ namespace jit { namespace fuser { namespace cuda { -// MUTATE FUNCTIONS FOR VALS +void OptOutMutator::mutate(Statement* s) { + Statement::mutatorDispatch(this, s); +} + +void OptOutMutator::mutate(Expr* e) { + Expr::mutatorDispatch(this, e); +} + +void OptOutMutator::mutate(Val* v) { + Val::mutatorDispatch(this, v); +} + +void OptOutMutator::registerMutation(Val* val, Val* mutation) { + bool val_is_ns = val->vtype() == ValType::NamedScalar; + bool mutation_is_ns = mutation->vtype() == ValType::NamedScalar; + bool val_is_scalar = val->vtype() == ValType::Scalar; + bool mutation_is_scalar = mutation->vtype() == ValType::Scalar; + TORCH_INTERNAL_ASSERT( + mutation->dtype() == val->dtype() && + (mutation->vtype() == val->vtype() || + ((val_is_ns && mutation_is_scalar) || + (mutation_is_ns && val_is_scalar))), + "Mutations are not allowed to change types, tried to go from: (", + val->vtype(), + ", ", + val->dtype(), + ") to: (", + mutation->vtype(), + ", ", + mutation->dtype(), + ")"); + mutations[val] = mutation; +} + +void OptOutMutator::mutate(Bool* b) {} + +void OptOutMutator::mutate(Double* d) {} -Statement* OptOutMutator::mutate(IterDomain* id) { - Val* start = mutateAsVal(id->start())->asVal(); - Val* extent = mutateAsVal(id->extent())->asVal(); - Val* stop_offset = mutateAsVal(id->stopOffset())->asVal(); +void OptOutMutator::mutate(Int* i) {} + +void OptOutMutator::mutate(NamedScalar* ns) {} + +void OptOutMutator::mutate(IterDomain* id) { + Val* start = maybeMutated(id->start()); + Val* extent = maybeMutated(id->extent()); + Val* stop_offset = maybeMutated(id->stopOffset()); if (start->sameAs(id->start()) && extent->sameAs(id->extent()) && stop_offset->sameAs(id->stopOffset())) { - return id; + return; } - Val* mutated_val = new IterDomain( + Val* mutated_val = IrBuilder::create( + id->container(), start, extent, stop_offset, id->getParallelType(), id->getIterType(), id->isRFactorProduct()); + if (id->hasPaddingToMultipleOfWarp()) { + mutated_val->as()->padToMultipleOfWarp( + id->getMaybeSizeAfterPadding()); + } registerMutation(id, mutated_val); - return mutated_val; } -Statement* OptOutMutator::mutate(TensorDomain* td) { - std::vector dom; +void OptOutMutator::mutate(TensorDomain* td) { bool mutated = false; - for (const auto i : c10::irange(td->nDims())) { - IterDomain* id = mutateAsVal(td->axis(i))->as(); - dom.push_back(id); - if (!id->sameAs(td->axis(i))) - mutated = true; - } - if (mutated) { - Val* mutated_val = new TensorDomain( - td->getRootDomain(), td->getRFactorDomain(), dom, td->contiguity()); - registerMutation(td, mutated_val); - return mutated_val; + auto updateIdVec = [&](const std::vector& ids) { + std::vector updated_ids; + for (auto id : ids) { + auto updated_id = maybeMutated(id)->as(); + updated_ids.push_back(updated_id); + if (!updated_id->sameAs(id)) { + mutated = true; + } + } + return updated_ids; + }; + + std::vector root_dom = updateIdVec(td->getRootDomain()); + std::vector rfactor_dom = td->hasRFactor() + ? updateIdVec(td->getMaybeRFactorDomain()) + : std::vector(); + std::vector domain = updateIdVec(td->domain()); + + if (!mutated) { + return; } - return td; -} -Statement* OptOutMutator::mutate(TensorView* tv) { - TensorDomain* td = mutateAsVal(tv->domain())->as(); + Val* mutated_val = IrBuilder::create( + td->container(), root_dom, rfactor_dom, domain, td->contiguity()); + registerMutation(td, mutated_val); +} +void OptOutMutator::mutate(TensorView* tv) { + TensorDomain* td = maybeMutated(tv->domain())->as(); if (!tv->domain()->sameAs(td)) { - TensorView* mutated_tv = new TensorView(td, tv->getDataType().value()); - registerMutation(tv, mutated_tv); - return mutated_tv; + tv->setDomain(td); } - return tv; -} - -Statement* OptOutMutator::mutate(Bool* b) { - return b; + // Don't register tv mutations as we just want to update the TD } -Statement* OptOutMutator::mutate(Double* d) { - return d; +void OptOutMutator::mutate(kir::Predicate*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } -Statement* OptOutMutator::mutate(Int* i) { - return i; -} - -Statement* OptOutMutator::mutate(NamedScalar* ns) { - return ns; +void OptOutMutator::mutate(kir::TensorIndex*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } // MUTATE FUNCTIONS FOR EXPRESSIONS. +void OptOutMutator::mutate(UnaryOp* uop) { + Val* out = maybeMutated(uop->out()); + Val* in = maybeMutated(uop->in()); -Statement* OptOutMutator::mutate(Split* s) { - IterDomain* ot = mutateAsVal(s->outer())->as(); - IterDomain* inr = mutateAsVal(s->inner())->as(); - IterDomain* in = mutateAsVal(s->in())->as(); - Val* fact = mutateAsVal(s->factor())->as(); - - if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) && - in->sameAs(s->in()) && areEqualScalars(fact, s->factor())) { - return s; + if (out->sameAs(uop->out()) && in->sameAs(uop->in())) { + return; } - FusionGuard::getCurFusion()->removeExpr(s); - return new Split(ot, inr, in, fact, s->innerSplit()); + auto container = uop->container(); + auto uop_type = uop->getUnaryOpType(); + container->removeExpr(uop); + IrBuilder::create(container, uop_type, out, in); } -Statement* OptOutMutator::mutate(Merge* m) { - IterDomain* ot = mutateAsVal(m->out())->as(); - IterDomain* otr = mutateAsVal(m->outer())->as(); - IterDomain* in = mutateAsVal(m->inner())->as(); +void OptOutMutator::mutate(BinaryOp* bop) { + Val* out = maybeMutated(bop->out()); + Val* lhs = maybeMutated(bop->lhs()); + Val* rhs = maybeMutated(bop->rhs()); - if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) && in->sameAs(m->inner())) - return m; - - FusionGuard::getCurFusion()->removeExpr(m); - return new Merge(ot, otr, in); -} - -Statement* OptOutMutator::mutate(UnaryOp* uop) { - Val* out = mutateAsVal(uop->out())->asVal(); - Val* in = mutateAsVal(uop->in())->asVal(); + if (out == bop->out() && lhs == bop->lhs() && rhs == bop->rhs()) { + return; + } - if (out->sameAs(uop->out()) && in->sameAs(uop->in())) - return uop; - FusionGuard::getCurFusion()->removeExpr(uop); - return new UnaryOp(uop->getUnaryOpType(), out, in); + auto container = bop->container(); + auto bop_type = bop->getBinaryOpType(); + container->removeExpr(bop); + IrBuilder::create(container, bop_type, out, lhs, rhs); } -Statement* OptOutMutator::mutate(BinaryOp* bop) { - Val* out = mutateAsVal(bop->out())->asVal(); - Val* lhs = mutateAsVal(bop->lhs())->asVal(); - Val* rhs = mutateAsVal(bop->rhs())->asVal(); - if (out == bop->out() && lhs == bop->lhs() && rhs == bop->rhs()) - return bop; - FusionGuard::getCurFusion()->removeExpr(bop); - return new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); -} +void OptOutMutator::mutate(TernaryOp* top) { + Val* out = maybeMutated(top->out()); + Val* in1 = maybeMutated(top->in1()); + Val* in2 = maybeMutated(top->in2()); + Val* in3 = maybeMutated(top->in3()); -Statement* OptOutMutator::mutate(TernaryOp* top) { - Val* out = mutateAsVal(top->out())->asVal(); - Val* in1 = mutateAsVal(top->in1())->asVal(); - Val* in2 = mutateAsVal(top->in2())->asVal(); - Val* in3 = mutateAsVal(top->in3())->asVal(); if (out == top->out() && in1 == top->in1() && in2 == top->in2() && - in3 == top->in3()) - return top; - FusionGuard::getCurFusion()->removeExpr(top); - return new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3); + in3 == top->in3()) { + return; + } + + auto container = top->container(); + auto top_type = top->getTernaryOpType(); + container->removeExpr(top); + IrBuilder::create(container, top_type, out, in1, in2, in3); } -Statement* OptOutMutator::mutate(ReductionOp* rop) { - Val* out = mutateAsVal(rop->out())->asVal(); - Val* in = mutateAsVal(rop->in())->asVal(); +void OptOutMutator::mutate(ReductionOp* rop) { + Val* out = maybeMutated(rop->out()); + Val* in = maybeMutated(rop->in()); Val* init = rop->init(); if (out->sameAs(rop->out()) && in->sameAs(rop->in()) && - init->sameAs(rop->init())) - return rop; + init->sameAs(rop->init())) { + return; + } - return new ReductionOp(rop->getReductionOpType(), init, out, in); + auto container = rop->container(); + auto rop_type = rop->getReductionOpType(); + container->removeExpr(rop); + IrBuilder::create(container, rop_type, init, out, in); } namespace { @@ -159,20 +194,18 @@ inline bool compareOptional(Val* a, Val* b) { } // namespace -Statement* OptOutMutator::mutate(WelfordOp* wop) { - Val* out_avg = mutateAsVal(wop->outAvg())->asVal(); - Val* out_var = mutateAsVal(wop->outVar())->asVal(); - Val* out_N = mutateAsVal(wop->outN())->asVal(); +void OptOutMutator::mutate(WelfordOp* wop) { + Val* out_avg = maybeMutated(wop->outAvg()); + Val* out_var = maybeMutated(wop->outVar()); + Val* out_N = maybeMutated(wop->outN()); - Val* in_avg = mutateAsVal(wop->inAvg())->asVal(); - Val* in_var = wop->inVar() ? mutateAsVal(wop->inVar())->asVal() : nullptr; - Val* in_N = mutateAsVal(wop->inN())->asVal(); + Val* in_avg = maybeMutated(wop->inAvg()); + Val* in_var = wop->inVar() ? maybeMutated(wop->inVar()) : nullptr; + Val* in_N = maybeMutated(wop->inN()); - Val* init_avg = - wop->initAvg() ? mutateAsVal(wop->initAvg())->asVal() : nullptr; - Val* init_var = - wop->initVar() ? mutateAsVal(wop->initVar())->asVal() : nullptr; - Val* init_N = mutateAsVal(wop->initN())->asVal(); + Val* init_avg = wop->initAvg() ? maybeMutated(wop->initAvg()) : nullptr; + Val* init_var = wop->initVar() ? maybeMutated(wop->initVar()) : nullptr; + Val* init_N = maybeMutated(wop->initN()); const bool out_compare = out_avg->sameAs(wop->outAvg()) && out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN()); @@ -182,56 +215,163 @@ Statement* OptOutMutator::mutate(WelfordOp* wop) { compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN()); if (out_compare && init_compare && in_compare) { - return wop; - } else { - return new WelfordOp( - out_avg, - out_var, - out_N, - init_avg, - init_var, - init_N, - in_avg, - in_var, - in_N); + return; } + + auto container = wop->container(); + container->removeExpr(wop); + IrBuilder::create( + container, + out_avg, + out_var, + out_N, + init_avg, + init_var, + init_N, + in_avg, + in_var, + in_N); } -Statement* OptOutMutator::mutate(BroadcastOp* bop) { - return bop; +void OptOutMutator::mutate(BroadcastOp* bop) { + Val* out = maybeMutated(bop->out()); + Val* in = maybeMutated(bop->in()); + + if (out->sameAs(bop->out()) && in->sameAs(bop->in())) { + return; + } + + auto container = bop->container(); + auto flags = bop->getBroadcastDimFlags(); + container->removeExpr(bop); + IrBuilder::create(container, out, in, flags); } -Statement* OptOutMutator::mutate(TransposeOp* top) { - return top; +void OptOutMutator::mutate(TransposeOp* top) { + TensorView* out = maybeMutated(top->out())->as(); + TensorView* in = maybeMutated(top->in())->as(); + + if (out->sameAs(top->out()) && in->sameAs(top->in())) { + return; + } + + auto container = top->container(); + auto new2old = top->new2old(); + container->removeExpr(top); + IrBuilder::create(container, out, in, new2old); } -Statement* OptOutMutator::mutate(ShiftOp* sop) { - Val* out = mutateAsVal(sop->out())->asVal(); - Val* in = mutateAsVal(sop->in())->asVal(); +void OptOutMutator::mutate(ShiftOp* sop) { + Val* out = maybeMutated(sop->out())->asVal(); + Val* in = maybeMutated(sop->in())->asVal(); + + if (out->sameAs(sop->out()) && in->sameAs(sop->in())) { + return; + } - if (out->sameAs(sop->out()) && in->sameAs(sop->in())) - return sop; auto offsets = sop->offsets(); - FusionGuard::getCurFusion()->removeExpr(sop); - return new ShiftOp(out, in, offsets, sop->pad()); + auto pad_width = sop->padWidth(); + auto container = sop->container(); + container->removeExpr(sop); + IrBuilder::create(container, out, in, offsets, pad_width); } -Statement* OptOutMutator::mutate(GatherOp* op) { - Val* out = mutateAsVal(op->out())->asVal(); - Val* in = mutateAsVal(op->in())->asVal(); +void OptOutMutator::mutate(GatherOp* op) { + Val* out = maybeMutated(op->out())->asVal(); + Val* in = maybeMutated(op->in())->asVal(); + + if (out->sameAs(op->out()) && in->sameAs(op->in())) { + return; + } - if (out->sameAs(op->out()) && in->sameAs(op->in())) - return op; auto window_shape = op->windowShape(); auto pad_width = op->padWidth(); - FusionGuard::getCurFusion()->removeExpr(op); - return new GatherOp(out, in, window_shape, pad_width); + auto container = op->container(); + container->removeExpr(op); + IrBuilder::create(container, out, in, window_shape, pad_width); +} + +void OptOutMutator::mutate(ViewOp* vop) { + TensorView* out = maybeMutated(vop->out())->as(); + TensorView* in = maybeMutated(vop->in())->as(); + + if (out->sameAs(vop->out()) && in->sameAs(vop->in())) { + return; + } + + auto container = vop->container(); + container->removeExpr(vop); + IrBuilder::create(container, out, in); +} + +void OptOutMutator::mutate(Split* s) { + IterDomain* ot = maybeMutated(s->outer())->as(); + IterDomain* inr = maybeMutated(s->inner())->as(); + IterDomain* in = maybeMutated(s->in())->as(); + Val* fact = maybeMutated(s->factor())->as(); + Val* start_offset = maybeMutated(s->startOffset()); + Val* stop_offset = maybeMutated(s->stopOffset()); + + if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) && + in->sameAs(s->in()) && areEqualScalars(fact, s->factor()) && + start_offset->sameAs(s->startOffset()) && + stop_offset->sameAs(s->stopOffset())) { + return; + } + + auto container = s->container(); + auto inner_split = s->innerSplit(); + container->removeExpr(s); + auto new_node = IrBuilder::create( + container, ot, inr, in, fact, inner_split, start_offset, stop_offset); +} + +void OptOutMutator::mutate(Merge* m) { + IterDomain* ot = maybeMutated(m->out())->as(); + IterDomain* otr = maybeMutated(m->outer())->as(); + IterDomain* in = maybeMutated(m->inner())->as(); + + if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) && + in->sameAs(m->inner())) { + return; + } + + auto container = m->container(); + container->removeExpr(m); + auto new_node = IrBuilder::create(container, ot, otr, in); } -Statement* OptOutMutator::mutate(ViewOp* vop) { - return vop; +void OptOutMutator::mutate(kir::Allocate*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::Sync*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::InitMagicZero*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::UpdateMagicZero*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::ForLoop*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::IfThenElse*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::GridReduction*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::GridBroadcast*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::GridWelford*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } +void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) { + container->removeExpr(expr); +} } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/mutator.h b/torch/csrc/jit/codegen/cuda/mutator.h index f9ec40ca9f5..433de485cf1 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.h +++ b/torch/csrc/jit/codegen/cuda/mutator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/non_divisible_split.h b/torch/csrc/jit/codegen/cuda/non_divisible_split.h index f17bf2d6246..6706c9f072d 100644 --- a/torch/csrc/jit/codegen/cuda/non_divisible_split.h +++ b/torch/csrc/jit/codegen/cuda/non_divisible_split.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.cpp b/torch/csrc/jit/codegen/cuda/ops/alias.cpp new file mode 100644 index 00000000000..14aff510911 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ops/alias.cpp @@ -0,0 +1,115 @@ +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +//! Transform TensorView according to keep, merge, and split transformations. +//! Trivial reduction and broadcast transformations are handled separately. +//! It is recommend to use the composite ops view function, which will call +//! the analyzeView function to generate the appropriate transformations. +//! +//! For example: +//! original sizes = [2, 10, 40] +//! new_size = [2, 10, 2, 20] +//! auto analysis = analyzeView(TV0, original_sizes, new_sizes) +//! auto TV1 = TV0->view(analysis.transforms); +//! +//! Transforms = [(Keep I0), (Keep I1), (Split I2 by 2)] +//! Before: TV0[I0, I1, I2] +//! After: TV0[I0, I1, 2, ceilDiv(I2, 2)] +//! +TensorView* applyViewTransforms( + TensorView* tv, + const std::vector>& transforms) { + TORCH_INTERNAL_ASSERT( + !tv->hasComputeAt(), + "Cannot modify rfactor domain after compute at has been set."); + + TORCH_INTERNAL_ASSERT(tv->nDims() > 0, "Tried to view a 0-dim TensorView"); + + TORCH_CHECK( + !tv->domain()->hasRFactor(), + "Cannot call view on the same TensorView twice."); + + TORCH_INTERNAL_ASSERT(!transforms.empty()); + + TensorView* consumer = IrBuilder::create( + tv->container(), + tv->domain()->view(transforms), + tv->getDataType().value()); + + IrBuilder::create(tv->container(), consumer, tv); + + return consumer; +} + +} // namespace + +TensorView* view( + TensorView* x, + const std::vector& original_sizes, + const std::vector& new_sizes) { + TORCH_INTERNAL_ASSERT(x->nDims() == original_sizes.size()); + + auto analyze_view = analyzeView(x, original_sizes, new_sizes); + + auto reduction = (!analyze_view.trivial_reduction_axes.empty()) + ? sum(x, analyze_view.trivial_reduction_axes) + : x; + + auto view = (!analyze_view.transforms.empty()) + ? applyViewTransforms(reduction, analyze_view.transforms) + : reduction; + + return (analyze_view.has_broadcast) + ? broadcast(view, analyze_view.broadcast_axes) + : view; +} + +TensorView* squeeze(TensorView* x, const std::vector& sizes) { + TORCH_INTERNAL_ASSERT(x->nDims() == sizes.size()); + + std::vector trivial_reduction_axes; + for (const auto idx : c10::irange(sizes.size())) { + if (sizes[idx] == 1) { + trivial_reduction_axes.push_back(idx); + } + } + return (trivial_reduction_axes.empty()) ? x : sum(x, trivial_reduction_axes); +} + +TensorView* squeeze(TensorView* x, const std::vector& sizes, int dim) { + TORCH_INTERNAL_ASSERT(x->nDims() == sizes.size()); + if (dim < 0) { + dim = (int)(x->nDims()) + dim; + } + TORCH_INTERNAL_ASSERT(dim >= 0 && dim < x->nDims()); + if (sizes[dim] == 1) { + return sum(x, {dim}); + } else { + return set(x); + } +} + +TensorView* unsqueeze(TensorView* x, int dim) { + if (dim < 0) { + dim = (int)(x->nDims()) + dim + 1; + } + TORCH_INTERNAL_ASSERT(dim >= 0 && dim <= x->nDims()); + + std::vector broadcast_axes(x->nDims() + 1, false); + broadcast_axes[dim] = true; + return broadcast(x, broadcast_axes); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.h b/torch/csrc/jit/codegen/cuda/ops/alias.h new file mode 100644 index 00000000000..8003e3268b3 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/ops/alias.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include +#include + +// +// The operations defined in this header is intended as user facing functions. +// The user will provide the necessary input TensorViews and the function will +// create the correct intermediate nodes and return the output TensorViews. +// + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +TORCH_CUDA_CU_API TensorView* view( + TensorView* x, + const std::vector& original_sizes, + const std::vector& new_sizes); + +TORCH_CUDA_CU_API TensorView* squeeze( + TensorView* x, + const std::vector& sizes); + +TORCH_CUDA_CU_API TensorView* squeeze( + TensorView* x, + const std::vector& sizes, + int dim); + +TORCH_CUDA_CU_API TensorView* unsqueeze(TensorView* x, int dim); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ops/all_ops.h b/torch/csrc/jit/codegen/cuda/ops/all_ops.h index 1ebd2bb87f1..07d3eb944e8 100644 --- a/torch/csrc/jit/codegen/cuda/ops/all_ops.h +++ b/torch/csrc/jit/codegen/cuda/ops/all_ops.h @@ -1,4 +1,5 @@ #pragma once #include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.cpp b/torch/csrc/jit/codegen/cuda/ops/composite.cpp index 06bcf2d0494..c01b7230625 100644 --- a/torch/csrc/jit/codegen/cuda/ops/composite.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/composite.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -8,9 +9,10 @@ namespace fuser { namespace cuda { ForwardDropoutResult dropout(TensorView* x, Val* prob) { - auto p1m = sub(new Double(1.), prob); - auto zero_check = add(eq(p1m, new Double(0.)), p1m); - auto scale = div(new Double(1.), zero_check); + auto p1m = sub(IrBuilder::create(x->container(), 1.), prob); + auto zero_check = + add(eq(p1m, IrBuilder::create(x->container(), 0.)), p1m); + auto scale = div(IrBuilder::create(x->container(), 1.), zero_check); return dropout(x, p1m, scale); } @@ -91,13 +93,14 @@ Val* fast_gelu(Val* x) { auto x_cube = mul(x, mul(x, x)); - auto inner_1 = mul(new Double(kKappa), x_cube); + auto inner_1 = mul(IrBuilder::create(x->container(), kKappa), x_cube); auto inner_2 = add(x, inner_1); - auto inner_3 = mul(new Double(kBeta), inner_2); + auto inner_3 = mul(IrBuilder::create(x->container(), kBeta), inner_2); auto tanh_inner = tanh(inner_3); - auto out = mul(x, add(new Double(1.), tanh_inner)); - auto y = mul(new Double(0.5), out); + auto out = + mul(x, add(IrBuilder::create(x->container(), 1.), tanh_inner)); + auto y = mul(IrBuilder::create(x->container(), 0.5), out); return y; } @@ -111,21 +114,25 @@ Val* fast_gelu_backward(Val* dy, Val* x) { auto x_sq = mul(x, x); auto x_cube = mul(x, x_sq); - auto inner_1 = mul(new Double(kKappa), x_cube); + auto inner_1 = mul(IrBuilder::create(x->container(), kKappa), x_cube); auto inner_2 = add(x, inner_1); - auto inner_3 = mul(new Double(kBeta), inner_2); + auto inner_3 = mul(IrBuilder::create(x->container(), kBeta), inner_2); auto tanh_inner = tanh(inner_3); - auto left = mul(new Double(0.5), x); - auto right = add(new Double(1.), tanh_inner); + auto left = mul(IrBuilder::create(x->container(), 0.5), x); + auto right = add(IrBuilder::create(x->container(), 1.), tanh_inner); - auto left_derivative = mul(new Double(0.5), right); + auto left_derivative = + mul(IrBuilder::create(x->container(), 0.5), right); auto tanh_inner_sq = mul(tanh_inner, tanh_inner); - auto tanh_derivative = sub(new Double(1), tanh_inner_sq); + auto tanh_derivative = + sub(IrBuilder::create(x->container(), 1), tanh_inner_sq); - auto constant_mul_x_sq = mul(new Double(kBeta * 3 * kKappa), x_sq); - auto inner_derivative = add(new Double(kBeta), constant_mul_x_sq); + auto constant_mul_x_sq = + mul(IrBuilder::create(x->container(), kBeta * 3 * kKappa), x_sq); + auto inner_derivative = + add(IrBuilder::create(x->container(), kBeta), constant_mul_x_sq); auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative)); auto dx = mul(dy, add(left_derivative, right_derivative)); @@ -139,79 +146,30 @@ Val* gelu_backward(Val* dy, Val* x) { constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5; const double kHalf = 0.5; - auto cdf_1 = mul(x, new Double(M_SQRT1_2)); + auto cdf_1 = mul(x, IrBuilder::create(x->container(), M_SQRT1_2)); auto cdf_2 = erf(cdf_1); - auto cdf_3 = add(cdf_2, new Double(1.)); - auto cdf_4 = mul(cdf_3, new Double(kHalf)); + auto cdf_3 = add(cdf_2, IrBuilder::create(x->container(), 1.)); + auto cdf_4 = mul(cdf_3, IrBuilder::create(x->container(), kHalf)); auto pdf_1 = mul(x, x); - auto pdf_2 = mul(pdf_1, new Double(-kHalf)); + auto pdf_2 = mul(pdf_1, IrBuilder::create(x->container(), -kHalf)); auto pdf_3 = exp(pdf_2); - auto out = addcmul(cdf_4, x, pdf_3, new Double(kAlpha)); + auto out = addcmul( + cdf_4, x, pdf_3, IrBuilder::create(x->container(), kAlpha)); auto dx = mul(out, dy); return dx; } -namespace { - -//! Transform TensorView according to keep, merge, and split transformations. -//! Trivial reduction and broadcast transformations are handled separately. -//! It is recommend to use the composite ops view function, which will call -//! the analyzeView function to generate the appropriate transformations. -//! -//! For example: -//! original sizes = [2, 10, 40] -//! new_size = [2, 10, 2, 20] -//! auto analysis = analyzeView(TV0, original_sizes, new_sizes) -//! auto TV1 = TV0->view(analysis.transforms); -//! -//! Transforms = [(Keep I0), (Keep I1), (Split I2 by 2)] -//! Before: TV0[I0, I1, I2] -//! After: TV0[I0, I1, 2, ceilDiv(I2, 2)] -//! -TensorView* applyViewTransforms( - TensorView* tv, - const std::vector>& transforms) { - TORCH_INTERNAL_ASSERT( - !tv->hasComputeAt(), - "Cannot modify rfactor domain after compute at has been set."); - - TORCH_INTERNAL_ASSERT(tv->nDims() > 0, "Tried to view a 0-dim TensorView"); - - TORCH_CHECK( - !tv->domain()->hasRFactor(), - "Cannot call view on the same TensorView twice."); - - TORCH_INTERNAL_ASSERT(!transforms.empty()); - - TensorView* consumer = - new TensorView(tv->domain()->view(transforms), tv->getDataType().value()); - - new ViewOp(consumer, tv); - - return consumer; -} - -} // namespace - -TensorView* view( - TensorView* x, - const std::vector& original_sizes, - const std::vector& new_sizes) { - auto analyze_view = analyzeView(x, original_sizes, new_sizes); - - auto reduction = (!analyze_view.trivial_reduction_axes.empty()) - ? sum(x, analyze_view.trivial_reduction_axes) - : x; - - auto view = (!analyze_view.transforms.empty()) - ? applyViewTransforms(reduction, analyze_view.transforms) - : reduction; +Val* tanh_backward(Val* dy, Val* tanh_x) { + TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid."); + TORCH_INTERNAL_ASSERT(tanh_x != nullptr, "Input is invalid"); - return (analyze_view.has_broadcast) - ? broadcast(view, analyze_view.broadcast_axes) - : view; + auto one = IrBuilder::create(tanh_x->container(), 1.); + auto tanh_sq = mul(tanh_x, tanh_x); + auto sub_tanh_sq = sub(one, tanh_sq); + auto dx = mul(dy, sub_tanh_sq); + return dx; } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.h b/torch/csrc/jit/codegen/cuda/ops/composite.h index 4470f0cc6f0..63e17629f40 100644 --- a/torch/csrc/jit/codegen/cuda/ops/composite.h +++ b/torch/csrc/jit/codegen/cuda/ops/composite.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -48,11 +48,7 @@ TORCH_CUDA_CU_API LstmResult lstm( TORCH_CUDA_CU_API Val* fast_gelu(Val* x); TORCH_CUDA_CU_API Val* fast_gelu_backward(Val* dy, Val* x); TORCH_CUDA_CU_API Val* gelu_backward(Val* dy, Val* x); - -TORCH_CUDA_CU_API TensorView* view( - TensorView* x, - const std::vector& x_sizes, - const std::vector& new_sizes); +TORCH_CUDA_CU_API Val* tanh_backward(Val* dy, Val* tanh_x); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp index 19201687553..4a473f66203 100644 --- a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace torch { @@ -23,7 +24,7 @@ TensorView* softmax(TensorView* x, int dim) { auto exp_val = exp(x_max_sub); auto sum_exp = sum(exp_val, {kReductionAxis}); auto bcast_sum = broadcast(sum_exp, broadcast_mask); - auto y = div(exp_val, bcast_sum); + auto y = mul(exp_val, reciprocal(bcast_sum)); return y; } @@ -88,7 +89,7 @@ ForwardNormResult layer_norm( std::vector inner_reduction_axes(kNormShapeNumDims); std::vector inner_broadcast_mask(kNumberOfDims, false); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(x->container(), 1); for (const auto idx : c10::irange(kNormShapeNumDims)) { const size_t axis = kNumberOfDims - 1 - idx; inner_reduction_axes[idx] = axis; @@ -102,7 +103,7 @@ ForwardNormResult layer_norm( auto x_sub_mean = sub(x, mean_bcast); auto var_sum_bcast = broadcast(welford_out.var_sum, inner_broadcast_mask); - auto var = div(var_sum_bcast, num_features); + auto var = mul(var_sum_bcast, reciprocal(num_features)); auto var_eps = add(var, eps); auto invstd = rsqrt(var_eps); @@ -156,7 +157,7 @@ BackwardNormResult layer_norm_backward( std::vector inner_reduction_axes(kNormShapeNumDims); std::vector inner_broadcast_mask(kNumberOfDims, false); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(x->container(), 1); for (const auto idx : c10::irange(kNormShapeNumDims)) { const size_t axis = kNumberOfDims - 1 - idx; inner_reduction_axes[idx] = axis; @@ -243,7 +244,7 @@ ForwardNormResult batch_norm( std::vector reduction_axes; std::vector broadcast_mask(kNumberOfDims, false); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(x->container(), 1); for (const auto axis : c10::irange(kNumberOfDims)) { if (axis != c_axis) { @@ -267,13 +268,15 @@ ForwardNormResult batch_norm( kTraining, "When running stats are provided, batch stats should only be computed during training"); - auto rev_momentum = sub(new Double(1.0), momentum); + auto rev_momentum = + sub(IrBuilder::create(x->container(), 1.0), momentum); auto current_mean_hat = mul(welford_out.avg, momentum); auto mean_hat = mul(running_mean, rev_momentum); auto new_mean_hat = add(mean_hat, current_mean_hat); - auto num_feature_decrement = sub(num_features, new Int(1)); - auto unbiased_var = div(welford_out.var_sum, num_feature_decrement); + auto num_feature_decrement = sub(num_features, x->container()->oneVal()); + auto unbiased_var = + mul(welford_out.var_sum, reciprocal(num_feature_decrement)); auto current_var_hat = mul(unbiased_var, momentum); auto var_hat = mul(running_var, rev_momentum); auto new_var_hat = add(var_hat, current_var_hat); @@ -301,14 +304,14 @@ ForwardNormResult batch_norm( fusion->aliasOutputToInput(casted_output, input_to_cast); }; - if (fusion->hasInput(running_mean)) { + if (running_mean->isFusionInput()) { fusion->addOutput(new_mean_hat); fusion->aliasOutputToInput(new_mean_hat, running_mean); } else { cast_to_input_dtype(running_mean, new_mean_hat); } - if (fusion->hasInput(running_var)) { + if (running_var->isFusionInput()) { fusion->addOutput(new_var_hat); fusion->aliasOutputToInput(new_var_hat, running_var); } else { @@ -320,7 +323,7 @@ ForwardNormResult batch_norm( auto mean_bcast = broadcast(mean, broadcast_mask); auto x_sub_mean = sub(x, mean_bcast); - auto var = div(welford_out.var_sum, num_features); + auto var = mul(welford_out.var_sum, reciprocal(num_features)); auto var_eps = add(var, eps); invstd = rsqrt(var_eps); auto invstd_bcast = broadcast(invstd, broadcast_mask); @@ -414,19 +417,6 @@ BackwardNormResult batch_norm_backward( mean = broadcast(mean, broadcast_mask); - TensorView* weight_val = nullptr; - if (weight == nullptr) { - weight_val = TensorViewBuilder() - .ndims(kNumberOfDims) - .dtype(input->getDataType().value()) - .shape(std::vector(kNumberOfDims, 1)) - .build(); - new UnaryOp( - UnaryOpType::Set, weight_val->as(), (new Double(1.0))->as()); - } else { - weight_val = broadcast(weight, broadcast_mask); - } - auto norm = reciprocal(num_features); auto grad_output_sum = sum(grad_output, reduction_axes); @@ -435,7 +425,16 @@ BackwardNormResult batch_norm_backward( auto grad_mean = broadcast(mul(grad_output_sum, norm), broadcast_mask); auto proj_scale = broadcast(mul(mul(dot_p, norm), mul(invstd, invstd)), broadcast_mask); - auto grad_scale = mul(broadcast(invstd, broadcast_mask), weight_val); + TensorView* grad_scale = nullptr; + + if (weight == nullptr) { + grad_scale = + mul(broadcast(invstd, broadcast_mask), + IrBuilder::create(input->container(), 1)); + } else { + grad_scale = mul( + broadcast(invstd, broadcast_mask), broadcast(weight, broadcast_mask)); + } TensorView* grad_input = nullptr; if (kTraining) { @@ -496,7 +495,7 @@ ForwardNormResult instance_norm( std::vector x_reduction_axes; std::vector x_broadcast_mask(kNumberOfDims, false); - Val* N = new Double(1); + Val* N = IrBuilder::create(x->container(), 1); for (const auto axis : c10::irange(kNumberOfDims)) { if (axis != kBatchDim && axis != kChannelsDim) { x_reduction_axes.push_back(axis); @@ -504,7 +503,7 @@ ForwardNormResult instance_norm( N = mul(N, x->domain()->domain()[axis]->extent()); } } - Val* B = new Double(1); + Val* B = IrBuilder::create(x->container(), 1); B = mul(B, x->domain()->domain()[kBatchDim]->extent()); std::vector channels_only_broadcast_mask(kNumberOfDims, false); @@ -523,7 +522,8 @@ ForwardNormResult instance_norm( // updating running mean and running var if (running_mean != nullptr && running_var != nullptr) { - auto rev_momentum = sub(new Double(1.0), momentum); + auto rev_momentum = + sub(IrBuilder::create(x->container(), 1.0), momentum); auto current_mean_hat = mul(welford_out.avg, momentum); auto mean_hat = mul(running_mean, rev_momentum); auto new_mean_hat = add(mean_hat, current_mean_hat); @@ -531,12 +531,13 @@ ForwardNormResult instance_norm( // NS: static_cast to workaround VC++ error, see // https://godbolt.org/z/6Prd77xYs auto new_mean_sum = sum(new_mean_hat, {static_cast(kBatchDim)}); - auto new_mean_channels_only = div(new_mean_sum, B); + auto new_mean_channels_only = mul(new_mean_sum, reciprocal(B)); fusion->addOutput(new_mean_channels_only); fusion->aliasOutputToInput(new_mean_channels_only, running_mean); - auto num_feature_decrement = sub(N, new Int(1)); - auto unbiased_var = div(welford_out.var_sum, num_feature_decrement); + auto num_feature_decrement = sub(N, x->container()->oneVal()); + auto unbiased_var = + mul(welford_out.var_sum, reciprocal(num_feature_decrement)); auto current_var_hat = mul(unbiased_var, momentum); auto var_hat = mul(running_var, rev_momentum); auto new_var_hat = add(var_hat, current_var_hat); @@ -544,7 +545,7 @@ ForwardNormResult instance_norm( // NS: static_cast to workaround VC++ error, see // https://godbolt.org/z/6Prd77xYs auto new_var_sum = sum(new_var_hat, {static_cast(kBatchDim)}); - auto new_var_channels_only = div(new_var_sum, B); + auto new_var_channels_only = mul(new_var_sum, reciprocal(B)); fusion->addOutput(new_var_channels_only); fusion->aliasOutputToInput(new_var_channels_only, running_var); } @@ -553,7 +554,7 @@ ForwardNormResult instance_norm( auto mean_bcast = broadcast(mean, x_broadcast_mask); auto x_sub_mean = sub(x, mean_bcast); - auto var = div(welford_out.var_sum, N); + auto var = mul(welford_out.var_sum, reciprocal(N)); auto var_eps = add(var, eps); invstd = rsqrt(var_eps); auto invstd_bcast = broadcast(invstd, x_broadcast_mask); diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.h b/torch/csrc/jit/codegen/cuda/ops/normalization.h index dae58462b92..b28cdf6b33c 100644 --- a/torch/csrc/jit/codegen/cuda/ops/normalization.h +++ b/torch/csrc/jit/codegen/cuda/ops/normalization.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp index 3dcb58335a4..d966fc21a97 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp +++ b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include #include @@ -102,7 +100,6 @@ void ParallelDimensionMap::populateDimensionMapWithSingleCASet( TORCH_INTERNAL_ASSERT(dom_set.size() == 1); const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // pt is used by only one concrete domain auto id = *dom_set.begin(); @@ -110,16 +107,16 @@ void ParallelDimensionMap::populateDimensionMapWithSingleCASet( if (it != constant_extent_map_.end()) { if (it->second.size() == 1) { - dim_map_.insert({pt, ir_builder.create(*(it->second.begin()))}); + dim_map_.insert({pt, IrBuilder::create(*(it->second.begin()))}); exact_types_.insert(pt); } else { // Multiple constant dimensions found; Use the corresponding // symbolic parallel dim - dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)}); + dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); } } else { // Prefer to use blockDim/gridDim if not constant - dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)}); + dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); exact_types_.insert(pt); } } @@ -130,11 +127,10 @@ void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( TORCH_INTERNAL_ASSERT(dom_set.size() > 1); const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); bool all_equal = true; // Use nullptr to signal it's not initialied yet - kir::Val* known_dimension = nullptr; + Val* known_dimension = nullptr; // Use -1 to signal it's not initialied yet int64_t known_const = -1; @@ -172,7 +168,7 @@ void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( // At this point, it still remains undetermined whether this id // matches with those previously looked at. Constant check failed, // but symbolic matching may succeed. - auto this_dimension = gpu_lower->lowerValue(concrete_id->extent()); + auto this_dimension = concrete_id->extent(); if (known_dimension == nullptr) { // No previous dimension found yet known_dimension = this_dimension; @@ -191,15 +187,14 @@ void ParallelDimensionMap::populateDimensionMapWithMultipleCASet( } // Use the const value, if found, as its dimension if (all_equal && known_const != -1) { - dim_map_.insert({pt, ir_builder.create(known_const)}); + dim_map_.insert({pt, IrBuilder::create(known_const)}); } else { - dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)}); + dim_map_.insert({pt, NamedScalar::getParallelDim(pt)}); } } void ParallelDimensionMap::adjustMappingsForWarpPadding() { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); // If TIDx is padded to a multiple of the warp size, mark it as // non-exact. @@ -215,7 +210,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { // If the dimension of TIDx is actually a multple of the warp size // before padding, it can be left as exact if (isExact(tidx_pt)) { - auto tidx_dim = dynamic_cast(get(tidx_pt)); + auto tidx_dim = dynamic_cast(get(tidx_pt)); if (tidx_dim && tidx_dim->isConst()) { auto tidx_dim_val = tidx_dim->value().value(); if (tidx_dim_val % warp_size == 0) { @@ -229,17 +224,17 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { // single warp, use the constant warp size as the dimension of // TIDx. Otherwise, jsut use blockDim.x. if (warp_info.is_tidx_single_warp) { - dim_map_.at(ParallelType::TIDx) = ir_builder.create(warp_size); + dim_map_.at(ParallelType::TIDx) = IrBuilder::create(warp_size); } else { dim_map_.at(ParallelType::TIDx) = - kir::NamedScalar::getParallelDim(ParallelType::TIDx); + NamedScalar::getParallelDim(ParallelType::TIDx); } // TIDx is no longer exact exact_types_.erase(ParallelType::TIDx); } -kir::Val* ParallelDimensionMap::get(ParallelType pt) const { +Val* ParallelDimensionMap::get(ParallelType pt) const { TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt); auto it = dim_map_.find(pt); if (it == dim_map_.end()) { @@ -261,7 +256,7 @@ IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) { // Symbolically compares equality of two KIR vals. Comparison is done // conservatively, so returning false does not guarantee non-equality. -bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { +bool ParallelDimensionMap::equalDim(Val* dim1, Val* dim2) { TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr); if (dim1 == dim2) { @@ -269,8 +264,8 @@ bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { } // When Both are Int, they are same if both have the same constant - auto dim1_int = dynamic_cast(dim1); - auto dim2_int = dynamic_cast(dim2); + auto dim1_int = dynamic_cast(dim1); + auto dim2_int = dynamic_cast(dim2); if (dim1_int && dim2_int) { if (dim1_int->isConst() && dim2_int->isConst()) { return dim1_int->value() == dim2_int->value(); @@ -279,8 +274,8 @@ bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { // When both are NamedScalar, they are same if Both have the same // name - auto dim1_ns = dynamic_cast(dim1); - auto dim2_ns = dynamic_cast(dim2); + auto dim1_ns = dynamic_cast(dim1); + auto dim2_ns = dynamic_cast(dim2); if (dim1_ns && dim2_ns) { return dim1_ns->name() == dim2_ns->name(); } @@ -297,12 +292,12 @@ bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) { // If both are BinaryOp or UnaryOp, check their inputs. Since these // Vals are IterDomain extents, UnaryOp should not occur, but // checking shouldn't be harmful. - if ((dim1_def->isA() && dim2_def->isA() && - (dim1_def->as()->operation() == - dim2_def->as()->operation())) || - (dim1_def->isA() && dim2_def->isA() && - (dim1_def->as()->operation() == - dim2_def->as()->operation()))) { + if ((dim1_def->isA() && dim2_def->isA() && + (dim1_def->as()->getBinaryOpType() == + dim2_def->as()->getBinaryOpType())) || + (dim1_def->isA() && dim2_def->isA() && + (dim1_def->as()->getUnaryOpType() == + dim2_def->as()->getUnaryOpType()))) { for (const auto i : c10::irange(dim1_def->inputs().size())) { (void)i; // Suppress unused variable warning if (!equalDim(dim1_def->inputs()[0], dim2_def->inputs()[0])) { @@ -321,7 +316,7 @@ std::string ParallelDimensionMap::toString() const { ss << pt << ": "; auto dim = get(pt); if (dim != nullptr) { - ss << kir::toString(dim); + ss << dim->toString(); if (isExact(pt)) { ss << ", exact"; } else { diff --git a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h index d05c17adea2..03bd513396f 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h +++ b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h @@ -21,7 +21,7 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { //! Returns the dimension of a ParallelType. nullptr is returned if //! a ParallelType is unused. - kir::Val* get(ParallelType pt) const; + Val* get(ParallelType pt) const; //! True if the dimension of a ParallelType is known to be exact bool isExact(ParallelType pt) const; @@ -29,7 +29,7 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { std::string toString() const; //! Symbolically analyze if two extent vals are equal - static bool equalDim(kir::Val* dim1, kir::Val* dim2); + static bool equalDim(Val* dim1, Val* dim2); private: //! Register the extent of an IterDomain if its constant @@ -54,7 +54,7 @@ class TORCH_CUDA_CU_API ParallelDimensionMap { private: //! Maps from parallel types to dimensions, which are constant if //! a unique value is found. - std::unordered_map dim_map_; + std::unordered_map dim_map_; //! Set of parallel types whose dimensions are identified to be //! exactly the same as extents of mapped domains. std::unordered_set exact_types_; diff --git a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h index 0bf8ae39277..3bfb32d38bc 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h +++ b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 11c27cffec2..94dad076db8 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -33,25 +34,18 @@ constexpr auto kNumBinaryFloatOps = 3; constexpr auto kNumBinaryComparisonOps = 12; constexpr auto kNumBinaryCastOps = 14; -constexpr auto kNumBinaryOpsWithAlpha = 4; +constexpr auto kNumBinaryOpsWithAlpha = 6; constexpr auto kNumLerpOps = 2; constexpr auto kNumLayernormFwd = 2; constexpr auto kNumBatchnormFwd = 3; constexpr auto kNumInstancenormFwd = 1; constexpr auto kNumSumToSize = 2; constexpr auto kNumAutocastOps = 2; -// constexpr auto kNumViewSize = 2; +constexpr auto kNumAliasDimOps = 2; +constexpr auto kNumViewOps = 2; namespace { -std::vector getTensorSizes(TensorTypePtr const& tensor_type) { - TORCH_INTERNAL_ASSERT(tensor_type != nullptr, "Input must be a Tensor."); - auto optional_sizes = tensor_type->sizes().concrete_sizes(); - TORCH_INTERNAL_ASSERT( - optional_sizes.has_value(), "Missing size information for the tensor."); - return optional_sizes.value(); -} - #define REGISTER_PARSE_RULE(op, func_body, ...) \ registerParseRule( \ op, \ @@ -59,7 +53,8 @@ std::vector getTensorSizes(TensorTypePtr const& tensor_type) { -> void func_body, \ __VA_ARGS__) -const auto& sizeAttr = Symbol::attr("profiled_size"); +const auto& reductionSizeAttr = Symbol::attr("profiled_reduction_size"); +const auto& viewSizeAttr = Symbol::attr("profiled_view_size"); const auto& intListAttr = Symbol::attr("profiled_int_list"); const auto& intAttr = Symbol::attr("profiled_int"); const auto& boolListAttr = Symbol::attr("profiled_bool_list"); @@ -283,8 +278,9 @@ class ValueHolder { if (iter_val != vals_.end()) { return iter_val->second; } - // patching scalar value, because memory format doesn't carry real meaning. - if (!is_tensor_view_) { + // patching scalar (tensor), memory format doesn't carry meaning and should + // just return the value as-is. + if (!is_tensor_view_ || rank() == 0) { return std::get<1>(getEntry()); } MemoryFormat format_s; @@ -505,7 +501,7 @@ class IrParser { "Failure when register value: ", *(val->node()), " with type: ", - val->type()); + val->type()->repr_str()); MemoryFormat format; Val* operand = nullptr; std::tie(format, operand) = value_map_[val->unique()].getEntry(); @@ -523,7 +519,6 @@ class IrParser { (opt_dtype.value() == DataType::Half || opt_dtype.value() == DataType::BFloat16)) { Val* promoted_val = castOp(DataType::Float, operand); - // value_map_.emplace(val->unique(), ValueHolder(promoted_val, format)); value_map_[val->unique()] = ValueHolder(promoted_val, format); } } @@ -688,7 +683,9 @@ class IrParser { "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", - "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"}; + "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", + "aten::rsub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::rsub(Tensor self, Scalar other, Scalar alpha) -> Tensor"}; for (auto signature : BinaryOpWithAlpha) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( @@ -704,6 +701,10 @@ class IrParser { BinaryOpType::Add, static_cast(&add_alpha))}, {aten::sub, + std::make_pair( + BinaryOpType::Sub, + static_cast(&sub_alpha))}, + {aten::rsub, std::make_pair( BinaryOpType::Sub, static_cast(&sub_alpha))}}); @@ -723,10 +724,12 @@ class IrParser { auto out = alpha->isOneInt() ? binaryOp( op_mapping[node->kind()].first, - lhs, - rhs, + node->kind() == aten::rsub ? rhs : lhs, + node->kind() == aten::rsub ? lhs : rhs, TypePromotion::default_op_config) - : op_mapping[node->kind()].second(lhs, rhs, alpha); + : (node->kind() == aten::rsub + ? op_mapping[node->kind()].second(rhs, lhs, alpha) + : op_mapping[node->kind()].second(lhs, rhs, alpha)); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, @@ -1101,10 +1104,10 @@ class IrParser { list_val.pop_front(); Val* low = value_map.count(node->inputs()[1]->unique()) != 0 ? *value_map[node->inputs()[1]->unique()] - : new Double(std::numeric_limits::min()); + : IrBuilder::create(std::numeric_limits::min()); Val* high = value_map.count(node->inputs()[2]->unique()) != 0 ? *value_map[node->inputs()[2]->unique()] - : new Double(std::numeric_limits::max()); + : IrBuilder::create(std::numeric_limits::max()); auto out = clamp(operand, low, high); value_map.emplace(node->output()->unique(), out); @@ -1340,7 +1343,7 @@ class IrParser { running_mean = value_map[node->input(3)->unique()]->as(); TORCH_INTERNAL_ASSERT( - fusion->hasInput(running_mean), + running_mean->isFusionInput(), "IO_tensor `instance_norm::running_mean` can only be input tensor to fusion"); } @@ -1350,7 +1353,7 @@ class IrParser { running_var = value_map[node->input(4)->unique()]->as(); TORCH_INTERNAL_ASSERT( - fusion->hasInput(running_var), + running_var->isFusionInput(), "IO_tensor `instance_norm::running_var` can only be input tensor to fusion"); } @@ -1364,7 +1367,7 @@ class IrParser { Val* momentum_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto momentum = constant_as(node->input(6))) { - momentum_ptr = new Double(momentum.value()); + momentum_ptr = IrBuilder::create(momentum.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) momentum_ptr = value_map[node->input(6)->unique()]; @@ -1373,7 +1376,7 @@ class IrParser { Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(7))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; @@ -1458,7 +1461,7 @@ class IrParser { Val* momentum_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto momentum = constant_as(node->input(6))) { - momentum_ptr = new Double(momentum.value()); + momentum_ptr = IrBuilder::create(momentum.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) momentum_ptr = value_map[node->input(6)->unique()]; @@ -1467,7 +1470,7 @@ class IrParser { Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(7))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; @@ -1586,7 +1589,7 @@ class IrParser { Val* eps_ptr = nullptr; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) if (auto eps = constant_as(node->input(9))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) eps_ptr = value_map[node->input(7)->unique()]; @@ -1704,7 +1707,7 @@ class IrParser { Val* eps_ptr = nullptr; if (auto eps = constant_as(node->input(4))) { - eps_ptr = new Double(eps.value()); + eps_ptr = IrBuilder::create(eps.value()); } else { eps_ptr = value_map[node->input(4)->unique()]; } @@ -2032,7 +2035,7 @@ class IrParser { keepdim.has_value(), "aten::mean cannot be fused with dynamic keepdim"); auto o_sum = sum(self, dims, keepdim.value()); - Val* num_features = new Double(1); + Val* num_features = IrBuilder::create(1); for (auto axis : dims) { if (axis < 0) { axis += int(self->nDims()); @@ -2347,6 +2350,31 @@ class IrParser { nullptr); } + { + auto ptr_op = getOperatorForLiteral( + "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, + value_map[node->inputs()[0]->unique()], + value_map[node->inputs()[1]->unique()]); + auto grad_out = list_val.front(); + list_val.pop_front(); + auto self = list_val.front(); + list_val.pop_front(); + + auto grad_in = tanh_backward(grad_out, self); + value_map.emplace( + node->output()->unique(), ValueHolder(grad_in, format)); + }, + nullptr, + nullptr); + } + { auto ptr_op = getOperatorForLiteral( "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"); @@ -2392,37 +2420,111 @@ class IrParser { }); } - /* - // TODO: Enable view in parser by detecting non-alias view operation { - std::array View = { - "aten::view(Tensor(a) self, int[] size) -> Tensor(a)", - "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"}; - for (auto signature : View) { + std::array ViewOps = { + "prim::reshape_copy(Tensor self, int[] shape) -> Tensor", + "prim::view_copy(Tensor self, int[] size) -> Tensor"}; + for (auto signature : ViewOps) { auto ptr_op = getOperatorForLiteral(signature); REGISTER_PARSE_RULE( ptr_op, { auto self_value = node->inputs()[0]; - auto self = value_map[self_value->unique()]->as(); + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + MemoryFormat::Contiguous(), value_map[self_value->unique()]); + auto self = list_val.front()->as(); + list_val.pop_front(); auto self_type = self_value->type()->cast(); TORCH_INTERNAL_ASSERT(self_type != nullptr); auto self_sizes = getTensorSizes(self_type); - auto size_optional = - constant_as>(node->input(1)); + auto view_sizes = constant_as>(node->input(1)); TORCH_INTERNAL_ASSERT( - size_optional.has_value(), "The size parameter is required."); + view_sizes.has_value(), "The size parameter is required."); - auto output = view(self, self_sizes, size_optional->vec()); + auto output = view(self, self_sizes, view_sizes->vec()); + value_map.emplace(node->output()->unique(), output); + }, + [](const Node* node) -> bool { + // Reject fusing node if view_sizes contains an inferred dimension + auto view_sizes = constant_as>(node->input(1)); + TORCH_INTERNAL_ASSERT( + view_sizes.has_value(), "The size parameter is required."); + for (auto axis_size : view_sizes->vec()) { + if (axis_size == -1) { + return false; + } + } + return true; + }, + nullptr); + } + } + + { + auto ptr_op = + getOperatorForLiteral("prim::squeeze_copy(Tensor self) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + auto self_value = node->inputs()[0]; + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + MemoryFormat::Contiguous(), value_map[self_value->unique()]); + auto self = list_val.front()->as(); + list_val.pop_front(); + + auto self_type = self_value->type()->cast(); + TORCH_INTERNAL_ASSERT(self_type != nullptr); + auto self_sizes = getTensorSizes(self_type); + + auto output = squeeze(self, self_sizes); + value_map.emplace(node->output()->unique(), output); + }, + nullptr, + nullptr); + } + + { + std::array AliasOpWithDim = { + "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor", + "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor"}; + for (auto signature : AliasOpWithDim) { + auto ptr_op = getOperatorForLiteral(signature); + REGISTER_PARSE_RULE( + ptr_op, + { + auto self_value = node->inputs()[0]; + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + MemoryFormat::Contiguous(), + value_map[node->inputs()[0]->unique()]); + auto self = list_val.front()->as(); + list_val.pop_front(); + + auto dim_value = constant_as(node->input(1)); + TORCH_INTERNAL_ASSERT(dim_value.has_value(), "dim is not valid"); + + TensorView* output = nullptr; + if (node->kind() == prim::unsqueeze_copy) { + output = unsqueeze(self, dim_value.value()); + } else { + auto self_type = self_value->type()->cast(); + TORCH_INTERNAL_ASSERT(self_type != nullptr); + auto self_sizes = getTensorSizes(self_type); + output = squeeze(self, self_sizes, dim_value.value()); + } value_map.emplace(node->output()->unique(), output); }, nullptr, nullptr); } } - */ } void processJitNode(const JitOp* node) { @@ -2456,9 +2558,9 @@ class IrParser { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { - cg_val = new Double(ival.value()); + cg_val = IrBuilder::create(ival.value()); } else { - cg_val = new Double(); + cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; @@ -2467,9 +2569,9 @@ class IrParser { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { - cg_val = new Int(ival.value()); + cg_val = IrBuilder::create(ival.value()); } else { - cg_val = new Int(); + cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; @@ -2478,9 +2580,9 @@ class IrParser { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CgValue cg_val; if (auto ival = constant_as(val)) { - cg_val = new Bool(ival.value()); + cg_val = IrBuilder::create(ival.value()); } else { - cg_val = new Bool(); + cg_val = IrBuilder::create(); } value_map_.emplace(val->unique(), cg_val); return true; @@ -2496,7 +2598,11 @@ class IrParser { // TODO: we don't support list type in codegen yet; // This is a WAR to allow axes of reduction to be passed as constant list; // We simply ignore conversion if the scalar value is a constant; - return toIValue(val).has_value(); + auto ivalue = toIValue(val); + TORCH_INTERNAL_ASSERT( + ivalue.has_value(), + "List[T] is not supported as an argument by NvFuser. Use a Constant List."); + return true; } return false; } @@ -2566,7 +2672,10 @@ class IrParser { tensor_type->undefined()); } - cg_val = new TensorView(tensor_type); + cg_val = IrBuilder::create(tensor_type); + if (is_cpu_scalar(*tensor_type)) { + cg_val->as()->setCpuScalar(true); + } value_map_.emplace(val->unique(), ValueHolder(cg_val, format)); return true; } @@ -2611,7 +2720,7 @@ ProfileIValueOp* insertProfileIValueOp( return pn; } -void profileSize(ProfilingRecord* pr, Node* node, size_t offset) { +void profileReductionSize(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); const auto ivalue_profiler = [pr, pn](Stack& stack) { @@ -2631,12 +2740,14 @@ void profileSize(ProfilingRecord* pr, Node* node, size_t offset) { size_vec.clear(); } else { TORCH_INTERNAL_ASSERT( - false, "profileSize does not support data type: ", value.tagKind()); + false, + "profileReductionSize does not support data type: ", + value.tagKind()); } - if (!pn->hasAttribute(sizeAttr)) { - pn->is_(sizeAttr, size_vec); + if (!pn->hasAttribute(reductionSizeAttr)) { + pn->is_(reductionSizeAttr, size_vec); } else { - auto profiled_ints = pn->is(sizeAttr); + auto profiled_ints = pn->is(reductionSizeAttr); TORCH_INTERNAL_ASSERT( profiled_ints.size() == size_vec.size() && std::equal( @@ -2648,6 +2759,39 @@ void profileSize(ProfilingRecord* pr, Node* node, size_t offset) { pn->setCallback(ivalue_profiler); } +void profileViewSize(ProfilingRecord* pr, Node* node, size_t offset) { + auto pn = insertProfileIValueOp(node, offset, pr); + + const auto ivalue_profiler = [pr, pn](Stack& stack) { + std::lock_guard lock(pr->mutex_); + + // TODO: we don't care about merging multiple profiling runs as we don't + // support it at all; + int64_t frame_id = 0; + pop(stack, frame_id); + IValue value; + pop(stack, value); + TORCH_INTERNAL_ASSERT( + value.isIntList(), "profiling seeing the wrong data type"); + if (!pn->hasAttribute(viewSizeAttr)) { + pn->is_(viewSizeAttr, value.toIntVector()); + } else { + auto profiled_ints = pn->is(viewSizeAttr); + auto input_ints = value.toIntList(); + TORCH_INTERNAL_ASSERT( + profiled_ints.size() == input_ints.size() && + std::equal( + profiled_ints.begin(), + profiled_ints.end(), + input_ints.begin()), + "profiling ivalue doesn't support merge"); + } + push(stack, value); + }; + + pn->setCallback(ivalue_profiler); +} + void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) { auto pn = insertProfileIValueOp(node, offset, pr); @@ -2943,7 +3087,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { // argument 1: reduction sizes; case 1: // TODO(profile_size): double check optional[size]? - profileSize(pr, node, offset); + profileReductionSize(pr, node, offset); break; default: return false; @@ -2951,28 +3095,52 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { return true; } - /* - // TODO: Enable view in parser by detecting non-alias view operation - static auto view_schema = + static auto reshape_schema = + getOperatorForLiteral("aten::reshape(Tensor self, int[] shape) -> Tensor") + ->schema(); + static auto reshape_copy_schema = getOperatorForLiteral( - "aten::view(Tensor(a) self, int[] size) -> Tensor(a)") + "prim::reshape_copy(Tensor self, int[] shape) -> Tensor") ->schema(); - static auto reshape_schema = + static auto view_schema = + getOperatorForLiteral("aten::view(Tensor self, int[] size) -> Tensor") + ->schema(); + static auto view_copy_schema = getOperatorForLiteral( - "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)") + "prim::view_copy(Tensor self, int[] size) -> Tensor") ->schema(); - if (node->matches(view_schema) || node->matches(reshape_schema)) { + if (node->matches(reshape_schema) || node->matches(reshape_copy_schema) || + node->matches(view_schema) || node->matches(view_copy_schema)) { switch (offset) { // argument 1: new tensor size; case 1: - profileSize(pr, node, offset); + profileViewSize(pr, node, offset); + break; + default: + return false; + } + return true; + } + + static auto squeeze_dim_schema = + getOperatorForLiteral( + "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor") + ->schema(); + static auto unsqueeze_schema = + getOperatorForLiteral( + "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor") + ->schema(); + if (node->matches(squeeze_dim_schema) || node->matches(unsqueeze_schema)) { + switch (offset) { + // argument 1: unsqueeze dim; + case 1: + profileInt(pr, node, offset); break; default: return false; } return true; } - */ static auto batch_norm_impl_index_schema = getOperatorForLiteral( diff --git a/torch/csrc/jit/codegen/cuda/parser.h b/torch/csrc/jit/codegen/cuda/parser.h index 4b2fcf50f99..6d52b325042 100644 --- a/torch/csrc/jit/codegen/cuda/parser.h +++ b/torch/csrc/jit/codegen/cuda/parser.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/partial_split_map.cpp b/torch/csrc/jit/codegen/cuda/partial_split_map.cpp index e7b6db4d165..e320e8ee373 100644 --- a/torch/csrc/jit/codegen/cuda/partial_split_map.cpp +++ b/torch/csrc/jit/codegen/cuda/partial_split_map.cpp @@ -12,7 +12,7 @@ void PartialSplitMap::build(Fusion* fusion) { auto used_vals = ir_utils::allTvs(fusion); for (auto tv : ir_utils::filterByType(used_vals)) { - auto exprs = ExprSort::getExprs( + auto exprs = StmtSort::getExprs( fusion, {tv->domain()->domain().begin(), tv->domain()->domain().end()}); for (auto split : ir_utils::filterByType(exprs)) { // Only needs to check root domains as partial split is only @@ -24,18 +24,10 @@ void PartialSplitMap::build(Fusion* fusion) { continue; } auto root_domain = split->in(); - auto kir_root_domain = - gpu_lower->lowerValue(split->in())->as(); auto start_offset = split->startOffset(); start_offset_map_.insert({root_domain, start_offset}); - kir_start_offset_map_.insert( - {kir_root_domain, - gpu_lower->lowerValue(start_offset)->as()}); auto stop_offset = split->stopOffset(); stop_offset_map_.insert({root_domain, stop_offset}); - kir_stop_offset_map_.insert( - {kir_root_domain, - gpu_lower->lowerValue(stop_offset)->as()}); } } } @@ -49,15 +41,6 @@ Val* PartialSplitMap::getStartOffset(IterDomain* root_domain) const { } } -kir::Val* PartialSplitMap::getStartOffset(kir::IterDomain* root_domain) const { - auto it = kir_start_offset_map_.find(root_domain); - if (it == kir_start_offset_map_.end()) { - return nullptr; - } else { - return it->second; - } -} - Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const { auto it = stop_offset_map_.find(root_domain); if (it == stop_offset_map_.end()) { @@ -67,15 +50,6 @@ Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const { } } -kir::Val* PartialSplitMap::getStopOffset(kir::IterDomain* root_domain) const { - auto it = kir_stop_offset_map_.find(root_domain); - if (it == kir_stop_offset_map_.end()) { - return nullptr; - } else { - return it->second; - } -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/partial_split_map.h b/torch/csrc/jit/codegen/cuda/partial_split_map.h index be432bd5a16..8ec489915b7 100644 --- a/torch/csrc/jit/codegen/cuda/partial_split_map.h +++ b/torch/csrc/jit/codegen/cuda/partial_split_map.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -20,15 +20,11 @@ class TORCH_CUDA_CU_API PartialSplitMap { void build(Fusion* fusion); Val* getStartOffset(IterDomain* root_domain) const; - kir::Val* getStartOffset(kir::IterDomain* root_domain) const; Val* getStopOffset(IterDomain* root_domain) const; - kir::Val* getStopOffset(kir::IterDomain* root_domain) const; private: std::unordered_map start_offset_map_; - std::unordered_map kir_start_offset_map_; std::unordered_map stop_offset_map_; - std::unordered_map kir_stop_offset_map_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp index 004c836ec4e..91d68494fd4 100644 --- a/torch/csrc/jit/codegen/cuda/partition.cpp +++ b/torch/csrc/jit/codegen/cuda/partition.cpp @@ -5,12 +5,15 @@ #include #include #include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { +const c10::DeviceIndex INVALID_INDEX = -2; + namespace { bool hasNonElementWiseOperation(const Node* node) { @@ -38,26 +41,61 @@ static c10::optional getDevice(const Value* value) { // not tensor type, return false as the op is not outputing scalar. return c10::nullopt; } - return value->type()->expectRef().device(); + auto tensor_type = value->type()->expectRef(); + // special case for scalar tensor: return c10::nullopt instead of cpu device. + // this allows us to fuse scalar cpu tensor with cuda tensor, while avoid + // merging ops with pure scalar cpu tensors. + if (is_cpu_scalar(tensor_type)) { + return c10::nullopt; + } + return tensor_type.device(); } static c10::optional getDevice(const Node* node) { - auto outputs = node->outputs(); - for (auto output : outputs) { - auto device = getDevice(output); + c10::optional ret = c10::nullopt; + auto merge_devices = [&ret](const c10::optional& device) { if (device.has_value()) { - return device; + if (ret.has_value()) { + if (ret.value() != device.value()) { + // invalidate device to reflect conflicts + ret->set_index(INVALID_INDEX); + // return false to indicate early termination + return false; + } else { + // same device, do nothing + return true; + } + } else { + // initialize return device + ret = device.value(); + return true; + } + } + // no device information, do nothing + return true; + }; + for (auto val : node->inputs()) { + if (!merge_devices(getDevice(val))) { + return ret; } } - return c10::nullopt; + for (auto val : node->outputs()) { + if (!merge_devices(getDevice(val))) { + return ret; + } + } + return ret; } static bool isFusibleDevice(const Node* node, const c10::Device device) { - for (auto value : node->outputs()) { - auto output_device = getDevice(value); - if (output_device.has_value() && output_device.value() != device) { - return false; - } + TORCH_INTERNAL_ASSERT( + device.index() != INVALID_INDEX, "fusible device needs to be validate"); + auto opt_device = getDevice(node); + // we can be more relaxed here as we known that this function tries to merge + // node into an existing `device` + if (opt_device.has_value() && + (opt_device->index() == INVALID_INDEX || opt_device != device)) { + return false; } return true; } @@ -65,10 +103,12 @@ static bool isFusibleDevice(const Node* node, const c10::Device device) { // TODO: we need to check input type when we handle `to()` static bool isFusibleDevice(const Node* node) { auto device = getDevice(node); + // be conservative and only fuse cuda operations, this avoids us initializing + // operations that produces cpu scalar outputs if (!device.has_value()) { - return true; + return false; } - return device->is_cuda() && + return device->index() != INVALID_INDEX && device->is_cuda() && (at::cuda::getDeviceProperties(device->index())->major >= 7 || !hasNonElementWiseOperation(node)); } @@ -400,7 +440,7 @@ bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) { bool fused = false; // TODO: lift the restriction of not fusing producer containing reduction when // we have proper scheduling. - if (isFusibleCudaFusionGroup(node)) { + if (isFusibleNode(node)) { // ensure if the node has a designated device, it's on the same device with // fusion. // TODO: is there a danger of us fusing operations that's supposed to be on @@ -408,7 +448,6 @@ bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) { auto device = getDevice(fusion); fused = (!device.has_value() || isFusibleDevice(node, device.value())); } - return fused; } diff --git a/torch/csrc/jit/codegen/cuda/partition.h b/torch/csrc/jit/codegen/cuda/partition.h index 0d8baca4700..b295cb582e5 100644 --- a/torch/csrc/jit/codegen/cuda/partition.h +++ b/torch/csrc/jit/codegen/cuda/partition.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include /* diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index b501a6133f6..6575b374423 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include @@ -20,27 +18,23 @@ namespace cuda { namespace { -bool isTensorIndexOp(kir::Expr* expr) { +bool isTensorIndexOp(Expr* expr) { const auto& outputs = expr->outputs(); return outputs.size() >= 1 && outputs[0]->isA(); } -bool isOutputLocal(const kir::Expr* expr) { +bool isOutputLocal(const Expr* expr) { return std::all_of( - expr->outputs().begin(), - expr->outputs().end(), - [](const kir::Val* output) { - return !output->isA() || - output->as()->memoryType() == MemoryType::Local; + expr->outputs().begin(), expr->outputs().end(), [](const Val* output) { + return !output->isA() || + output->as()->getMemoryType() == MemoryType::Local; }); } } // namespace -bool ParallelizedDomainPredicate::PredicateInfo::addDomain( - kir::IterDomain* id) { - const auto gpu_lower = GpuLower::current(); - auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(id); +bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) { + auto concrete_id = GpuLower::current()->caIndexMap().getConcreteMappedID(id); if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) { ids_.push_back(concrete_id); return true; @@ -49,21 +43,19 @@ bool ParallelizedDomainPredicate::PredicateInfo::addDomain( } } -kir::Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { - const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); - - kir::Bool* pred = nullptr; +Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { + Bool* pred = nullptr; - auto index = - ir_builder.create(stringifyThread(pt_), DataType::Int); + auto index = SimplifyingIrBuilder::create( + stringifyThread(pt_), DataType::Int); for (const auto& pred_id : ids()) { // Just sanity check that pred_id is concrete TORCH_INTERNAL_ASSERT( - pred_id == gpu_lower->caIndexMap().getConcreteMappedID(pred_id)); - auto new_pred = ir_builder.ltExpr(index, pred_id->extent()); - pred = ir_builder.andExpr(pred, new_pred)->as(); + pred_id == + GpuLower::current()->caIndexMap().getConcreteMappedID(pred_id)); + auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent()); + pred = SimplifyingIrBuilder::andExpr(pred, new_pred)->as(); } return pred; @@ -74,16 +66,12 @@ namespace { std::unordered_set getNonUnswitchedRootDomains( const std::vector& loops, size_t unswitched_loop_index) { - const auto gpu_lower = GpuLower::current(); - std::vector non_unswited_leaf_domains; std::transform( loops.begin(), loops.begin() + unswitched_loop_index, std::back_inserter(non_unswited_leaf_domains), - [&](kir::ForLoop* loop) { - return gpu_lower->caIndexMap().toFusion(loop->iter_domain()); - }); + [&](kir::ForLoop* loop) { return loop->iter_domain(); }); auto non_unswitched_inputs = IterVisitor::getInputsTo(non_unswited_leaf_domains); @@ -100,26 +88,23 @@ std::unordered_set getNonUnswitchedRootDomains( non_unswitched_concrete_root_domains, non_unswitched_concrete_root_domains.end()), [&](auto root_dom) { - return gpu_lower->caIndexMap().getConcreteMappedID(root_dom); + return GpuLower::current()->caIndexMap().getConcreteMappedID(root_dom); }); return non_unswitched_concrete_root_domains; } bool isFullyUnswitched( - kir::IterDomain* loop_id, + IterDomain* loop_id, const std::unordered_set& non_unswitched_root_domains) { - const auto gpu_lower = GpuLower::current(); - - auto root_vals = - IterVisitor::getInputsTo({gpu_lower->caIndexMap().toFusion(loop_id)}); + auto root_vals = IterVisitor::getInputsTo({loop_id}); auto root_domains = ir_utils::filterByType(root_vals); return std::none_of( root_domains.begin(), root_domains.end(), [&](auto root_dom) { auto concrete_root_dom = - gpu_lower->caIndexMap().getConcreteMappedID(root_dom); + GpuLower::current()->caIndexMap().getConcreteMappedID(root_dom); return non_unswitched_root_domains.count(concrete_root_dom) > 0; }); } @@ -131,12 +116,10 @@ std::unordered_map< ParallelizedDomainPredicate::PredicateInfo, TypeHash> ParallelizedDomainPredicate::getPredicateMap( - const kir::Expr* expr, + const Expr* expr, const std::vector& loops, kir::ForLoop* unswitched_loop) { const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - auto output_tvs = ir_utils::getTvs(expr->outputs()); if (output_tvs.empty()) { @@ -167,7 +150,7 @@ ParallelizedDomainPredicate::getPredicateMap( } auto loop_id = loop->iter_domain(); - auto loop_ptype = loop_id->parallelType(); + auto loop_ptype = loop_id->getParallelType(); // Not necessary to add a predicate if the paralle type is exact if (!isParallelTypeThread(loop_ptype) || @@ -193,7 +176,7 @@ ParallelizedDomainPredicate::getPredicateMap( continue; } - kir::IterDomain* tv_id = *it; + IterDomain* tv_id = *it; // If the corresponding domain is a broadcast, it's not really used. if (tv_id->isBroadcast()) { @@ -203,9 +186,9 @@ ParallelizedDomainPredicate::getPredicateMap( // If it's a root domain, it should be covered by the root // predicates, so no extra predicate is required. if (std::find( - tv->domain()->rootDomain().begin(), - tv->domain()->rootDomain().end(), - tv_id) != tv->domain()->rootDomain().end()) { + tv->domain()->getRootDomain().begin(), + tv->domain()->getRootDomain().end(), + tv_id) != tv->domain()->getRootDomain().end()) { continue; } @@ -218,26 +201,24 @@ ParallelizedDomainPredicate::getPredicateMap( return map; } -kir::Bool* ParallelizedDomainPredicate::getPredicate( - const kir::Expr* expr, +Bool* ParallelizedDomainPredicate::getPredicate( + const Expr* expr, const std::vector& loops) { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - auto pred_map = getPredicateMap(expr, loops); - kir::Val* pred = ir_builder.trueVal(); + Val* pred = GpuLower::current()->kernel()->trueVal(); for (auto pt : kParallelTypeThreads) { auto pred_info_it = pred_map.find(pt); if (pred_info_it != pred_map.end()) { const auto& pred_info = pred_info_it->second; auto tid_pred = pred_info.getPredicate(); - pred = ir_builder.andExpr(pred, tid_pred); + pred = SimplifyingIrBuilder::andExpr(pred, tid_pred); } } if (pred) { - return pred->as(); + return pred->as(); } else { return nullptr; } @@ -256,61 +237,55 @@ UnswitchPredicateKey::UnswitchPredicateKey() // concrete domains are used to uniquely collect all necessary // unswitch predicates. UnswitchPredicateKey::UnswitchPredicateKey( - IterDomain* predicated_concrete_id, - const ReferenceTensor& reference) + IterDomain* predicated_consumer_id, + TensorView* consumer_tv, + IterDomain* predicated_concrete_id) : predicated_concrete_id_(predicated_concrete_id) { // Initialize the parallelized domain map for (auto pt : kParallelTypeThreads) { parallel_concrete_ids_.insert({pt, nullptr}); } - // The id parameter is a concrete domain. Needs to find the - // corresponding reference domain to find leaf domains that are - // parallelized. - IterDomain* predicated_ref_id = - reference.concrete_to_id.at(predicated_concrete_id_); - TensorDomain* ref_td = reference.domain; - - std::vector all_parallelized_ref_leaf_ids; + std::vector all_parallelized_consumer_leaf_ids; std::copy_if( - ref_td->domain().begin(), - ref_td->domain().end(), - std::back_inserter(all_parallelized_ref_leaf_ids), + consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end(), + std::back_inserter(all_parallelized_consumer_leaf_ids), [](IterDomain* x) { return isParallelTypeThread(x->getParallelType()); }); - // If the reference is not parallelized at all, no need to + // If the consumer domais are not parallelized at all, no need to // differentiate keys based on how the predicated id is parallelized - if (all_parallelized_ref_leaf_ids.empty()) { + if (all_parallelized_consumer_leaf_ids.empty()) { return; } - // All domains that are parallelized descendants of predicated_ref_id - auto all_parallelized_ref_ids = DependencyCheck::getAllValsBetween( - {predicated_ref_id}, all_parallelized_ref_leaf_ids); + // All domains that are parallelized descendants of predicated_consumer_id + auto all_parallelized_consumer_ids = DependencyCheck::getAllValsBetween( + {predicated_consumer_id}, all_parallelized_consumer_leaf_ids); // Just pick leaf domains - std::vector parallelized_ref_leaf_ids; + std::vector parallelized_consumer_leaf_ids; std::copy_if( - ref_td->domain().begin(), - ref_td->domain().end(), - std::back_inserter(parallelized_ref_leaf_ids), + consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end(), + std::back_inserter(parallelized_consumer_leaf_ids), [&](IterDomain* x) { return std::find( - all_parallelized_ref_ids.begin(), - all_parallelized_ref_ids.end(), - x) != all_parallelized_ref_ids.end(); + all_parallelized_consumer_ids.begin(), + all_parallelized_consumer_ids.end(), + x) != all_parallelized_consumer_ids.end(); }); - if (parallelized_ref_leaf_ids.empty()) { - // None of the parallelized leaf domains are derived from predicated_ref_id + if (parallelized_consumer_leaf_ids.empty()) { + // None of the parallelized leaf domains are derived from + // predicated_consumer_id return; } // Find the corresponding concrete id for each parallel type - for (auto ref_leaf : parallelized_ref_leaf_ids) { - auto pt = ref_leaf->getParallelType(); - auto it = reference.id_to_concrete.find(ref_leaf); - TORCH_INTERNAL_ASSERT(it != reference.id_to_concrete.end()); - auto concrete_leaf = it->second; + for (auto consumer_leaf : parallelized_consumer_leaf_ids) { + auto pt = consumer_leaf->getParallelType(); + auto concrete_leaf = + GpuLower::current()->caIndexMap().getConcreteMappedID(consumer_leaf); parallel_concrete_ids_.at(pt) = concrete_leaf; } } @@ -344,19 +319,18 @@ std::size_t UnswitchPredicateKeyHash::operator()( return h; }; -kir::Bool* PredicateCompute::getInlinePredicate( - const kir::Expr* expr, +Bool* PredicateCompute::getInlinePredicate( + const Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, PredicateType pred_type) { FUSER_PERF_SCOPE("GpuLower::Lower::getInlinePredicate"); const auto gpu_lower = GpuLower::current(); - kir::SimplifyingIrBuilder ir_builder(gpu_lower->kernel()); // If outputs are registers, no need to predicate for threads if (isOutputLocal(expr)) { - thread_pred = ir_builder.trueVal(); + thread_pred = gpu_lower->kernel()->trueVal(); } if (loops.empty()) { @@ -364,8 +338,8 @@ kir::Bool* PredicateCompute::getInlinePredicate( return thread_pred; } - auto out_tv = ir_utils::getTVOutput(expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output"); + auto out_tv = ir_utils::getTvOutput(expr); + TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); if (gpu_lower->predicateElimination().canOmitPredicate(expr)) { return thread_pred; @@ -376,7 +350,7 @@ kir::Bool* PredicateCompute::getInlinePredicate( out_tv, loops, nullptr, pred_type == PredicateType::Padding) .first; - std::vector preds; + std::vector preds; // When pred_type is ReductionWrite, filter out predicates for // reduction axes. For blockReduce, this is necessary when reduction @@ -388,7 +362,7 @@ kir::Bool* PredicateCompute::getInlinePredicate( bool non_zero_start_found = false; for (const auto& pred_info : pred_info_vec) { if (pred_type == PredicateType::ReductionWrite) { - const auto& consumer_ids = pred_info.consumerIds(); + const auto& consumer_ids = pred_info.rootIds(); bool pred_for_reduction_axis = false; for (auto consumer_id : consumer_ids) { if (consumer_id->isReduction()) { @@ -404,21 +378,15 @@ kir::Bool* PredicateCompute::getInlinePredicate( continue; } } - for (auto pred : pred_info.startPredicates()) { - TORCH_INTERNAL_ASSERT(pred != nullptr); - preds.push_back(pred); - } - for (auto pred : pred_info.stopPredicates()) { - TORCH_INTERNAL_ASSERT(pred != nullptr); - preds.push_back(pred); - } + preds.push_back(pred_info.startPredicate()); + preds.push_back(pred_info.stopPredicate()); } // When generating a predicate for blockReduce writes and not for // gridReduce, if all reduction axes start with zero, we can just // use the same predicate for reads. nullptr is returned then. if (pred_type == PredicateType::ReductionWrite && !non_zero_start_found && - !out_tv->fuserTv()->domain()->hasGridReduction()) { + !out_tv->domain()->hasGridReduction()) { return nullptr; } @@ -433,35 +401,33 @@ kir::Bool* PredicateCompute::getInlinePredicate( } if (preds.empty()) { - return ir_builder.trueVal(); + return GpuLower::current()->kernel()->trueVal(); } - kir::Val* cond = preds[0]; + Val* cond = preds[0]; for (const auto i : c10::irange(1, preds.size())) { - cond = ir_builder.andExpr(cond, preds[i]); + cond = SimplifyingIrBuilder::andExpr(cond, preds[i]); } - return cond->as(); + return cond->as(); } -kir::Bool* UnswitchPredicate::get( +Bool* UnswitchPredicate::get( const std::vector& outer_loops, kir::ForLoop* unrolled_loop) { FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::get"); - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); - UnswitchPredicate up(outer_loops, unrolled_loop); - kir::Val* unswitch_pred = ir_builder.trueVal(); + Val* unswitch_pred = GpuLower::current()->kernel()->trueVal(); for (auto pred : up.predicates_) { - unswitch_pred = ir_builder.andExpr(unswitch_pred, pred); + unswitch_pred = SimplifyingIrBuilder::andExpr(unswitch_pred, pred); } - return unswitch_pred->as(); + return unswitch_pred->as(); } -void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { +void UnswitchPredicate::predicateOn(Expr* tv_expr) { FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::predicateOn"); if (for_loops_.empty()) { @@ -469,14 +435,12 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { } const auto gpu_lower = GpuLower::current(); - kir::IrBuilder ir_builder(gpu_lower->kernel()); - if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr)) { return; } - auto out_tv = ir_utils::getTVOutput(tv_expr); - TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output"); + auto out_tv = ir_utils::getTvOutput(tv_expr); + TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); auto ref_pred_info = Index::getReferenceRootPredicates( out_tv, for_loops_, unrolled_loop_, false); @@ -491,10 +455,8 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { // predicates are generated in the finalize function. for (const auto& pred_info : ref_pred_info.first) { - if (pred_info.startPredicates().empty() && - pred_info.stopPredicates().empty()) { - continue; - } + TORCH_INTERNAL_ASSERT(pred_info.startPredicate() != nullptr); + TORCH_INTERNAL_ASSERT(pred_info.stopPredicate() != nullptr); const auto& root_ids = pred_info.rootIds(); @@ -505,13 +467,14 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { bool first_key_set = false; for (auto root_id : root_ids) { - auto kir_root_id = gpu_lower->lowerValue(root_id)->as(); + auto concrete_root_id = + gpu_lower->caIndexMap().getConcreteMappedID(root_id); - if (kir_root_id->isBroadcast()) { + if (root_id->isBroadcast()) { continue; } - UnswitchPredicateKey key(root_id, reference); + UnswitchPredicateKey key(root_id, out_tv, concrete_root_id); auto inserted = predicated_keys_.insert(key).second; add_pred = add_pred || inserted; @@ -573,14 +536,14 @@ void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) { // start and stop offsets. if (merged_pred_it != pending_predicates_.end()) { mergeUnswitchPredicateOffsets( - pred_info.startPredicates(), - pred_info.startOffsets(), + pred_info.startPredicate(), + pred_info.startOffset(), merged_pred_it->start, true); mergeUnswitchPredicateOffsets( - pred_info.stopPredicates(), - pred_info.stopOffsets(), + pred_info.stopPredicate(), + pred_info.stopOffset(), merged_pred_it->stop, false); } @@ -613,7 +576,7 @@ void UnswitchPredicate::openLoop(kir::ForLoop* fl) { for_loops_.push_back(fl); for (auto expr : fl->body().exprs()) { - if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) { + if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) { predicateOn(expr); } else if (auto ite = dynamic_cast(expr)) { openIte(ite); @@ -630,7 +593,7 @@ void UnswitchPredicate::openIte(kir::IfThenElse* ite) { // only expand the ite thenBody for (auto expr : ite->thenBody().exprs()) { - if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) { + if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) { predicateOn(expr); } else if (auto ite = dynamic_cast(expr)) { openIte(ite); @@ -641,7 +604,6 @@ void UnswitchPredicate::openIte(kir::IfThenElse* ite) { } void UnswitchPredicate::finalize() { - kir::SimplifyingIrBuilder ir_builder(GpuLower::current()->kernel()); for (const auto& merged_pred : pending_predicates_) { const auto& start_info = merged_pred.start; if (start_info.static_pred) { @@ -661,12 +623,10 @@ void UnswitchPredicate::finalize() { } void UnswitchPredicate::mergeUnswitchPredicateOffsets( - const std::vector& predicates, - const std::vector& offsets, + Bool* predicate, + Val* offset, MergedPredicates::Info& merged_predicate_info, bool is_start) { - TORCH_INTERNAL_ASSERT(predicates.size() == offsets.size()); - auto is_more_restrictive = [&is_start](int64_t new_val, int64_t current_val) { if (is_start) { return new_val < current_val; @@ -675,25 +635,21 @@ void UnswitchPredicate::mergeUnswitchPredicateOffsets( } }; - for (const auto i : c10::irange(predicates.size())) { - auto pred = predicates.at(i); - auto offset = offsets.at(i); - auto offset_int = dynamic_cast(offset); - // If it's a static predicate, replace the current one if it's - // more restrictive. If it's dynamic, just adds it to the dynamic - // predicate list. - if (offset_int && offset_int->isConst()) { - auto offset_const = offset_int->value().value(); - auto& static_pred = merged_predicate_info.static_pred; - auto& static_offset = merged_predicate_info.static_offset; - if (static_pred == nullptr || - is_more_restrictive(offset_const, static_offset)) { - static_pred = pred; - static_offset = offset_const; - } - } else { - merged_predicate_info.dynamic_preds.push_back(pred); + auto offset_int = dynamic_cast(offset); + // If it's a static predicate, replace the current one if it's + // more restrictive. If it's dynamic, just adds it to the dynamic + // predicate list. + if (offset_int && offset_int->isConst()) { + auto offset_const = offset_int->value().value(); + auto& static_pred = merged_predicate_info.static_pred; + auto& static_offset = merged_predicate_info.static_offset; + if (static_pred == nullptr || + is_more_restrictive(offset_const, static_offset)) { + static_pred = predicate; + static_offset = offset_const; } + } else { + merged_predicate_info.dynamic_preds.push_back(predicate); } } diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.h b/torch/csrc/jit/codegen/cuda/predicate_compute.h index 989bffb3bd1..c6412671e43 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.h +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.h @@ -16,10 +16,10 @@ class PredicateCompute { // ignore_internal_syncthread_ops will prevent creation of predicates on // block/grid broadcast/reduce as these have syncthread calls within them // so all threads need to execute the function. - static kir::Bool* getInlinePredicate( - const kir::Expr* expr, + static Bool* getInlinePredicate( + const Expr* expr, const std::vector& loops, - kir::Bool* thread_pred, + Bool* thread_pred, PredicateType pred_type); }; @@ -40,31 +40,31 @@ class ParallelizedDomainPredicate { explicit PredicateInfo(ParallelType pt) : pt_(pt) {} //! Adds a domain that is parallized by the same paralell type - bool addDomain(kir::IterDomain* id); + bool addDomain(IterDomain* id); - const std::vector& ids() const { + const std::vector& ids() const { return ids_; } //! Generates a predicate Val from predicate information - kir::Bool* getPredicate() const; + Bool* getPredicate() const; private: ParallelType pt_; //! Domains parallelized by the same parallel type - std::vector ids_; + std::vector ids_; }; //! Returns a predicate Val for parallelied domains of an expression. - static kir::Bool* getPredicate( - const kir::Expr* expr, + static Bool* getPredicate( + const Expr* expr, const std::vector& loops); //! Returns predicate information for parallelied domains of an //! expression. static std::unordered_map getPredicateMap( - const kir::Expr* expr, + const Expr* expr, const std::vector& loops, kir::ForLoop* unswitched_loop = nullptr); }; @@ -80,8 +80,9 @@ class UnswitchPredicateKey { UnswitchPredicateKey(); UnswitchPredicateKey( - IterDomain* predicated_concrete_id, - const ReferenceTensor& reference); + IterDomain* predicated_consumer_id, + TensorView* consumer_tv, + IterDomain* predicated_concrete_id); bool operator==(const UnswitchPredicateKey& other) const { return predicated_concrete_id_ == other.predicated_concrete_id_ && @@ -121,7 +122,7 @@ struct UnswitchPredicateKeyHash { class TORCH_CUDA_CU_API UnswitchPredicate { public: - static kir::Bool* get( + static Bool* get( const std::vector& outer_loops, kir::ForLoop* unrolled_loop); @@ -132,11 +133,11 @@ class TORCH_CUDA_CU_API UnswitchPredicate { struct Info { //! Most restrictive static predicate. Nullptr if no static //! predicate found. - kir::Bool* static_pred = nullptr; + Bool* static_pred = nullptr; //! The offset value of static_pred int64_t static_offset = 0; //! List of dynamic predicates. - std::vector dynamic_preds; + std::vector dynamic_preds; }; UnswitchPredicateKey predicate_key; Info start; @@ -147,7 +148,7 @@ class TORCH_CUDA_CU_API UnswitchPredicate { std::vector outer_loops, kir::ForLoop* unrolled_loop); - void predicateOn(kir::Expr*); + void predicateOn(Expr*); void openLoop(kir::ForLoop*); @@ -160,8 +161,8 @@ class TORCH_CUDA_CU_API UnswitchPredicate { //! static, only pick the most restrictive one, e.g., the one with the //! minimum offset for the start predication. void mergeUnswitchPredicateOffsets( - const std::vector& predicates, - const std::vector& offsets, + Bool* predicate, + Val* offset, MergedPredicates::Info& merged_predicate_info, bool is_start); @@ -181,7 +182,7 @@ class TORCH_CUDA_CU_API UnswitchPredicate { parallelized_dom_predicates_; //! The predicates that have been generated. - std::vector predicates_; + std::vector predicates_; std::vector for_loops_; diff --git a/torch/csrc/jit/codegen/cuda/reference_tensor.h b/torch/csrc/jit/codegen/cuda/reference_tensor.h index 2220831dc09..07c83bb6ed7 100644 --- a/torch/csrc/jit/codegen/cuda/reference_tensor.h +++ b/torch/csrc/jit/codegen/cuda/reference_tensor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index ddb92371baa..b48c6b00b3a 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -196,7 +196,7 @@ UnmappableReductionDomains::UnmappableReductionDomains() { namespace { -//! Find all domains that a given domain is depeendent on +//! Find all domains that a given domain is dependent on class FindInputDomains : BackwardVisitor { private: FindInputDomains(TensorView* tv, const IterDomain* id) @@ -661,6 +661,58 @@ void ComputeAtRootDomainMapBuilder::setMapped( root_map_.eq_set_.join(producer, consumer); } +void ComputeAtRootDomainMapBuilder::setInvalid( + const DomainKey& key1, + const DomainKey& key2) { + invalid_mappings_.emplace_back(key1, key2); +} + +bool ComputeAtRootDomainMapBuilder::isInvalid( + const std::vector& domains) const { + // First, collect all invalid mappings for each of the keys in domains + DomainKeyMap invalid_key_map; + for (const auto& key : domains) { + DomainKeySet invalid_keys; + for (const auto& invalid_pair : invalid_mappings_) { + if (root_map_.canMap(key, invalid_pair.first)) { + invalid_keys.insert(invalid_pair.second); + } else if (root_map_.canMap(key, invalid_pair.second)) { + invalid_keys.insert(invalid_pair.first); + } + } + invalid_key_map.emplace(key, invalid_keys); + } + + // Next, check if any pair is invalid to map. + const auto num_keys = domains.size(); + for (const auto i : c10::irange(num_keys)) { + const auto& key_i = domains[i]; + // If no invalid keys found for key_i, it can be skipped. + const auto invalid_key_map_it = invalid_key_map.find(key_i); + if (invalid_key_map_it == invalid_key_map.end()) { + continue; + } + + // Set of keys that are invalid to be mapped with key_i. + const DomainKeySet& invalid_keys_for_i = invalid_key_map_it->second; + + // If any other key in domains is identified mappable with any of + // the keys in this set, the mapping with key_i is invalid. + for (const auto j : c10::irange(i + 1, num_keys)) { + const auto& key_j = domains[j]; + if (std::any_of( + invalid_keys_for_i.begin(), + invalid_keys_for_i.end(), + [&](const auto& invalid_key_for_i) { + return root_map_.canMap(key_j, invalid_key_for_i); + })) { + return true; + } + } + } + return false; +} + void ComputeAtRootDomainMapBuilder::setMaybeMapped( const TensorDomain* producer_td, const IterDomain* producer_id, @@ -853,9 +905,11 @@ bool ComputeAtRootDomainMapBuilder::mapAllConsumers( // All entries in key_set must be equivalent with each other. TORCH_INTERNAL_ASSERT(consumer_set.size() > 0); bool consistent = safeToMap(consumer_set); - if (consistent) { - for (const auto pending_consumer : consumer_set) { + for (const auto pending_consumer : consumer_set) { + if (consistent) { setMapped(producer_key, pending_consumer); + } else { + setInvalid(producer_key, pending_consumer); } } // This entry should never be used again, so remove it. @@ -931,6 +985,10 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { !map_through_reduction_) { return false; } + // Make sure mapping these domains won't cause any invalid mapping + if (isInvalid(unique_domains)) { + return false; + } return true; } diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index 23ada0fb120..5156dc604f1 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -5,7 +5,7 @@ #include #include -#include +#include namespace torch { namespace jit { @@ -110,7 +110,7 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap { const TensorView* consumer_tv_ = nullptr; }; -std::string toString(const PairwiseRootDomainMap& root_map); +TORCH_CUDA_CU_API std::string toString(const PairwiseRootDomainMap& root_map); //! Represents an iteration domain of a TensorDomain. Only used for //! root domain mapping. @@ -206,7 +206,7 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor { //! This will create mappings between i0, i2 and i4. class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { friend class ComputeAtRootDomainMapBuilder; - friend std::string toString(const ComputeAtRootDomainMap&); + friend TORCH_CUDA_CU_API std::string toString(const ComputeAtRootDomainMap&); public: //! Builds a mapping table by analyzing the current @@ -327,7 +327,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { std::unordered_set window_axes_; }; -std::string toString(const ComputeAtRootDomainMap& root_map); +TORCH_CUDA_CU_API std::string toString(const ComputeAtRootDomainMap& root_map); //! Create a DisjointSet of root IterDomains by traversing the //! current fusion entirely. IterDomains that can be mapped each @@ -347,6 +347,12 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder //! Set a pair of producer-consumer domain keys as mappable void setMapped(const DomainKey& producer, const DomainKey& consumer); + //! Records two domains are invalid to map + void setInvalid(const DomainKey& key1, const DomainKey& key2); + + //! Check if no pair of domains is invalid to map + bool isInvalid(const std::vector& domains) const; + //! Track a pair of producer-consumer domains as potentially mappable. Inserts //! entries into pending_map_, but does not add anything into the root_map_ //! (added when handle is called on a TensorView). Maybe mapped will, however, @@ -415,10 +421,13 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder private: ComputeAtRootDomainMap& root_map_; - //! Keep track of what we want to try and map. Set in attemptToProveId. + //! Keep track of what we want to try and map DomainKeyMap pending_map_; std::unordered_set visited_; + //! Helper class to find invalid mappings due to reductions UnmappableReductionDomains incompatible_domains_; + //! Running vector of domain pairs that are invalid to map + std::vector> invalid_mappings_; //! Disable UnmappableReductions check, should //! always be false for compute_at use cases diff --git a/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu b/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu index ed366132689..fcbc98e7818 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu @@ -41,10 +41,8 @@ __device__ void sync() { // threads have incremented the counter. while (local_sync_counter < next && old < local_sync_counter) { #if __CUDA_ARCH__ >= 700 - __nanosleep(backoff); -#else - // __nanosleep is not available for sm < 70 - assert(false); + // __nanosleep only available on compute capability 7.0 or higher + __nanosleep(backoff); // avoids busy waiting #endif if (backoff < backoff_max) { backoff *= 2; diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu index a75d0d5904a..83382f4704c 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu @@ -69,7 +69,7 @@ template < typename Func> __device__ void gridReduceLastBlock( T& out, - const T* in, + const volatile T* in, const nvfuser_index_t grid_reduction_segment_size, // Number of reductions across // grid reduce dimensions diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu index 0ccb07142aa..a134bd81c2d 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu @@ -54,10 +54,8 @@ __device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { // Put a sleep here so we have some breaks in probing the global // semaphore, giving a better chance for other warps/blocks to catch up. #if __CUDA_ARCH__ >= 700 - __nanosleep(200); -#else - // __nanosleep is not available for sm < 70 - assert(false); + // __nanosleep only available on compute capability 7.0 or higher + __nanosleep(200); // avoids busy waiting #endif } } diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu index 61dccb4dff2..02fd8bf8777 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu @@ -279,3 +279,19 @@ template <> double pow(double a, double b) { return ::pow(a, b); } + +float pow(float a, int b) { + return pow(a, (float)b); +} + +double pow(double a, int b) { + return pow(a, (double)b); +} + +float pow(float a, int64_t b) { + return pow(a, (float)b); +} + +double pow(double a, int64_t b) { + return pow(a, (double)b); +} diff --git a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu index aab51a8f158..ac4f2069b3b 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tensor.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tensor.cu @@ -19,3 +19,13 @@ struct Tensor { T* data; }; + +// Specialization for 0-dim case that's easy to pass in a CPU based tensor. +template +struct CpuScalarTensor { + __device__ T& operator[](int) { + return data; + }; + + T data; +}; diff --git a/torch/csrc/jit/codegen/cuda/runtime/welford.cu b/torch/csrc/jit/codegen/cuda/runtime/welford.cu index 07d848c55f2..c3b09d82b74 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/welford.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/welford.cu @@ -8,8 +8,8 @@ __inline__ __device__ void welfordCombine( T& a_avg, T& a_M2, TN& a_N, - const T& b_avg, - const T& b_M2, + const T b_avg, + const T b_M2, TN b_N) { if (b_N == 0) { return; @@ -183,9 +183,9 @@ __device__ void gridWelfordLastBlock( T& out_avg, T& out_M2, TN& out_N, - const T* in_avg, - const T* in_M2, - const TN* in_N, + const volatile T* in_avg, + const volatile T* in_M2, + const volatile TN* in_N, const nvfuser_index_t grid_reduction_segment_size, // Number of reductions across // grid reduce dimensions @@ -345,9 +345,9 @@ __device__ void gridWelford( out_avg, out_M2, out_N, - (T*)work_buf_avg, - (T*)work_buf_M2, - (TN*)work_buf_N, + work_buf_avg, + work_buf_M2, + work_buf_N, grid_reduction_segment_size, block_reduction_segment_size, shared_buf_avg, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index b856d83ac92..8aa3081fcc6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -43,6 +43,9 @@ ReductionParams innerPersistentHeuristic( // Set some targets for parallelization const int64_t n_elems = total_reduction_numel * total_iteration_numel; + const int64_t outer_reduction_numel = + total_reduction_numel / inner_most_dimension_numel; + // WARNING: At some point we may want to generate heuristics for another // device that is not the current device. const int64_t device_max_threads_per_multiprocessor = @@ -228,7 +231,7 @@ ReductionParams innerPersistentHeuristic( bdimz = std::min( std::min( std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1), - ceilDiv(total_reduction_numel, inner_most_dimension_numel)), + outer_reduction_numel), scheduler_utils::z_block_limit); // If 3D doesn't fill out the threads, adjust to add to bdimy @@ -251,15 +254,13 @@ ReductionParams innerPersistentHeuristic( bdimz = std::min( std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1), - ceilDiv(total_reduction_numel, inner_most_dimension_numel)); + outer_reduction_numel); bdimy = std::min( std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1), max_multi_reduction_factor); } - godim = ceilDiv(total_iteration_numel, bdimy); - bool vectorize = false; // Move unrolling factor into vectorization upto vectorization limit. @@ -275,8 +276,7 @@ ReductionParams innerPersistentHeuristic( if (inner_reduction_unroll_factor < max_unroll) { outer_reduction_unroll_factor = std::min( ceilDiv(max_unroll, inner_reduction_unroll_factor), - ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), bdimz)); + ceilDiv(outer_reduction_numel, bdimz)); } godim = ceilDiv(total_iteration_numel, bdimy); @@ -304,9 +304,8 @@ ReductionParams innerPersistentHeuristic( while (outer_reduction_unroll_factor < max_unroll && batches_per_block_outer_reduction >= 2) { outer_reduction_unroll_factor *= 2; - batches_per_block_outer_reduction = roundUpPow2Or8(ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), - bdimz * outer_reduction_unroll_factor)); + batches_per_block_outer_reduction = roundUpPow2Or8( + ceilDiv(outer_reduction_numel, bdimz * outer_reduction_unroll_factor)); } // If we haven't gotten to the max_unroll case, try to take it out of the @@ -334,7 +333,7 @@ ReductionParams innerPersistentHeuristic( inner_most_dimension_numel, inner_reduction_unroll_factor * batches_per_block_inner_reduction); bdimz = ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), + outer_reduction_numel, outer_reduction_unroll_factor * batches_per_block_outer_reduction); // Try moving persistent buffer factors into threads until we have too many @@ -368,9 +367,8 @@ ReductionParams innerPersistentHeuristic( batches_per_block_outer_reduction = roundUpPow2Or8(batches_per_block_outer_reduction / 2); bdimz = ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), + outer_reduction_numel, batches_per_block_outer_reduction * outer_reduction_unroll_factor); - continue; } break; @@ -410,13 +408,18 @@ ReductionParams innerPersistentHeuristic( pad_bdimx = pad_bdimx && bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel; + // Will be used once supporting inter-block persistence + int64_t gdimx = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimy = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimz = LaunchParams::UNINITIALIZED_VAL; + ReductionParams rparams; rparams.persistent_kernel = true; rparams.fastest_dim = true; // Inner reduction domain - rparams.cross_block_inner_reduce = true; + rparams.cross_block_inner_reduction = true; rparams.block_dim_inner_reduction = ParallelType::TIDx; rparams.pad_inner_reduction_to_warp = pad_bdimx; rparams.batches_per_block_inner_reduction = batches_per_block_inner_reduction; @@ -432,8 +435,15 @@ ReductionParams innerPersistentHeuristic( if (rparams.multiple_reds_per_blk) { rparams.block_dim_iter_dom = ParallelType::TIDy; } - rparams.grid_dim_iter_dom = ParallelType::BIDx; - rparams.split_grid_dim_iter_dom = godim > scheduler_utils::x_grid_limit; + + if (godim > 1) { + rparams.grid_dim_iter_dom = ParallelType::BIDx; + if (godim > scheduler_utils::x_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimx = scheduler_utils::x_grid_limit; + } + } + if (iter_unroll_factor > 1) { rparams.unroll_iter_dom = true; rparams.unroll_factor_iter_dom = iter_unroll_factor; @@ -445,15 +455,15 @@ ReductionParams innerPersistentHeuristic( rparams.batches_per_block_outer_reduction = batches_per_block_outer_reduction; rparams.block_dim_outer_reduction = ParallelType::TIDz; - rparams.cross_block_outer_reduce = true; + rparams.cross_block_outer_reduction = true; rparams.unroll_outer_reduction = outer_reduction_unroll_factor > 1; rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor; } rparams.lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, - LaunchParams::UNINITIALIZED_VAL, - LaunchParams::UNINITIALIZED_VAL, + gdimx, + gdimy, + gdimz, LaunchParams::UNINITIALIZED_VAL, bdimy, LaunchParams::UNINITIALIZED_VAL); @@ -697,8 +707,8 @@ ReductionParams OuterPersistentHeuristic( rparams.persistent_kernel = true; rparams.fastest_dim = false; - rparams.cross_block_inner_reduce = true; - rparams.cross_grid_inner_reduce = false; + rparams.cross_block_inner_reduction = true; + rparams.cross_grid_inner_reduction = false; rparams.multiple_reds_per_blk = bdimx > 1; if (rparams.multiple_reds_per_blk) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index fb478f1110f..fb465b287e6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -391,6 +391,12 @@ class DomainMap { return nullptr; } + static bool hasReferenceTensorView(Fusion* fusion) { + FusionGuard fg(fusion); + DomainMap domain_map(fusion); + return domain_map.findReferenceTensorView() != nullptr; + } + private: // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input. @@ -417,7 +423,8 @@ class DomainMap { // Get concrete IDs for input root or rfactor domain std::unordered_set in_concrete_ids; for (auto in_id : input_tv->getMaybeRFactorDomain()) { - if (!in_id->isBroadcast() && !in_id->isReduction()) { + if (!ca_index_map_.getConcreteMappedID(in_id)->isBroadcast() && + !in_id->isReduction()) { in_concrete_ids.insert(ca_index_map_.getConcreteMappedID(in_id)); } } @@ -491,6 +498,10 @@ class DomainMap { } // namespace +bool hasReferenceTensorView(Fusion* fusion) { + return DomainMap::hasReferenceTensorView(fusion); +} + // TODO: Inline intermediate operations (avoid inlining unrolled/vectorized // input/output caches) void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { @@ -503,7 +514,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // maybe has_reduction for scheduling should be done on a per output tensor // basis. TORCH_INTERNAL_ASSERT( - !fusion->hasReduction(), "This scheduler only handles pointwise ops."); + ir_utils::getReductionOps(fusion).empty(), + "This scheduler only handles pointwise ops."); // For intermediate outputs, apply cache_fork auto outs = fusion->outputs(); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h index cb626556579..57b77bb20cc 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h @@ -31,6 +31,11 @@ TORCH_CUDA_CU_API LaunchParams schedulePointwise( Fusion* fusion, const at::ArrayRef& runtime_inputs); +//! Utility for canSchedule interface to check if this fusion has +//! a fully broadcasted reference tensor, which is necessary for +//! the pointwise scheduler. +bool hasReferenceTensorView(Fusion* fusion); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index b0d4f12b921..088968b0890 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -334,9 +334,9 @@ ReductionParams innerReductionHeuristic( ReductionParams rparams; rparams.fastest_dim = true; - rparams.cross_block_inner_reduce = true; + rparams.cross_block_inner_reduction = true; rparams.block_dim_inner_reduction = ParallelType::TIDx; - rparams.cross_grid_inner_reduce = gridim > 1; + rparams.cross_grid_inner_reduction = gridim > 1; rparams.multiple_reds_per_blk = bdimy > 1; bool pad_bdimx = bdimx > 16 && bdimx * bdimy < @@ -359,7 +359,9 @@ ReductionParams innerReductionHeuristic( rparams.vectorize_inner_reduction = vectorize; } - rparams.block_dim_iter_dom = ParallelType::TIDy; + if (rparams.multiple_reds_per_blk) { + rparams.block_dim_iter_dom = ParallelType::TIDy; + } if (iter_unroll_factor > 1) { rparams.unroll_iter_dom = true; rparams.unroll_factor_iter_dom = iter_unroll_factor; @@ -368,10 +370,10 @@ ReductionParams innerReductionHeuristic( rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel; // Outer reduction domain if (rparams.schedule_3D) { - rparams.cross_grid_outer_reduce = grodim > 1; + rparams.cross_grid_outer_reduction = grodim > 1; if (bdimz > 1) { rparams.block_dim_outer_reduction = ParallelType::TIDz; - rparams.cross_block_outer_reduce = true; + rparams.cross_block_outer_reduction = true; } rparams.unroll_outer_reduction = outer_reduction_unroll_factor > 1; rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor; @@ -385,39 +387,40 @@ ReductionParams innerReductionHeuristic( // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in // case it's larger than gdimy can hold, as not doing so can thrash the cache. - if (rparams.cross_grid_inner_reduce) { + if (rparams.cross_grid_inner_reduction) { rparams.grid_dim_inner_reduction = ParallelType::BIDx; - gdimx = gridim; - rparams.split_grid_dim_inner_reduction = - gdimx > scheduler_utils::x_grid_limit; + rparams.split_grid_dim_inner_reduction = true; + gdimx = std::min(gridim, scheduler_utils::x_grid_limit); rparams.grid_dim_iter_dom = ParallelType::BIDy; - gdimy = godim; - rparams.split_grid_dim_iter_dom = gdimy > scheduler_utils::y_grid_limit; + if (godim > scheduler_utils::y_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimy = std::min(godim, scheduler_utils::y_grid_limit); + } } else { - gdimx = godim; rparams.grid_dim_iter_dom = ParallelType::BIDx; - rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; + if (gdimx > scheduler_utils::x_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimx = godim; + } } - if (rparams.cross_grid_outer_reduce) { - if (rparams.cross_block_inner_reduce) { - gdimz = grodim; + if (rparams.cross_grid_outer_reduction) { + if (rparams.cross_block_inner_reduction) { rparams.grid_dim_outer_reduction = ParallelType::BIDz; + gdimz = std::min(grodim, scheduler_utils::z_grid_limit); + rparams.split_grid_dim_outer_reduction = true; } else { - gdimy = grodim; rparams.grid_dim_outer_reduction = ParallelType::BIDy; + gdimy = std::min(grodim, scheduler_utils::y_grid_limit); + rparams.split_grid_dim_outer_reduction = true; } } rparams.lparams = LaunchParams( - rparams.grid_dim_iter_dom == ParallelType::BIDx - ? LaunchParams::UNINITIALIZED_VAL - : gdimx, - rparams.grid_dim_iter_dom == ParallelType::BIDy - ? LaunchParams::UNINITIALIZED_VAL - : gdimy, + gdimx, + gdimy, gdimz, bdimx, bdimy > 1 ? bdimy : LaunchParams::UNINITIALIZED_VAL, @@ -441,12 +444,13 @@ ReductionParams innerReductionHeuristic( // schedule if (rparams.schedule_3D) { if (rparams.multiple_reds_per_blk && - (rparams.cross_grid_inner_reduce || rparams.cross_grid_outer_reduce)) { + (rparams.cross_grid_inner_reduction || + rparams.cross_grid_outer_reduction)) { if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { std::cerr << "\n===== UNSUPPORTED REDUCTION HEURISTIC ========\n"; std::cerr << rparams.multiple_reds_per_blk << ", " << rparams.unroll_inner_reduction << ", " - << rparams.cross_grid_inner_reduce << std::endl; + << rparams.cross_grid_inner_reduction << std::endl; } return innerReductionHeuristic( total_reduction_numel, @@ -534,9 +538,9 @@ ReductionParams OuterReductionHeuristic( // domain for this // Blocks for reductions - int64_t gdimy = 1; + int64_t grdim = 1; // Blocks for outputs - int64_t gdimx = 1; + int64_t gidim = 1; // Threads for reduction int64_t bdimy = 1; @@ -597,11 +601,11 @@ ReductionParams OuterReductionHeuristic( std::min(max_unroll, ceilDiv(total_reduction_numel, bdimy)); // Go cross grid - gdimy = ceilDiv( + grdim = ceilDiv( ceilDiv(total_reduction_numel, bdimy * inner_reduction_unroll_factor), (int64_t)4); - gdimx = ceilDiv(total_iteration_numel, bdimx * iter_unroll_factor); + gidim = ceilDiv(total_iteration_numel, bdimx * iter_unroll_factor); // Clang tidy constexpr int64_t kEight = 8; @@ -611,13 +615,13 @@ ReductionParams OuterReductionHeuristic( if (ceilDiv(total_reduction_numel, bdimy * inner_reduction_unroll_factor) >= kThirtyTwo) { // Many reduction elements, go cross grid - int64_t min_gdimy = 1; - if (gdimy > 1) { + int64_t min_grdim = 1; + if (grdim > 1) { // already cross grid, don't go below target or what was already set - min_gdimy = std::min(gdimy, ceilDiv(target_blocks, gdimx)); + min_grdim = std::min(grdim, ceilDiv(target_blocks, gidim)); } - gdimy = std::max( - min_gdimy, + grdim = std::max( + min_grdim, ceilDiv( ceilDiv( total_reduction_numel, bdimy * inner_reduction_unroll_factor), @@ -625,33 +629,33 @@ ReductionParams OuterReductionHeuristic( // Don't go too far above number of threads in a block since that's how many // threads are available to do final reduction iteration // This is good! - gdimy = std::min(gdimy, bdimx * bdimy * kEight); + grdim = std::min(grdim, bdimx * bdimy * kEight); } // Try to do some cleanup of ragged waves on device if ( // If we have less than 8 waves of blocks - gdimy * gdimx < device_multiprocessor_count * kEight && + grdim * gidim < device_multiprocessor_count * kEight && // And we don't have an even divisible number of blocks - (gdimy * gdimx) % device_multiprocessor_count != 0 && + (grdim * gidim) % device_multiprocessor_count != 0 && // And we have more than one wave - gdimy * gdimx > device_multiprocessor_count) { + grdim * gidim > device_multiprocessor_count) { // round waves down auto waves = - std::max((gdimx * gdimy) / device_multiprocessor_count, (int64_t)1); - auto new_gdimy = - std::max((waves * device_multiprocessor_count) / gdimx, (int64_t)1); + std::max((gidim * grdim) / device_multiprocessor_count, (int64_t)1); + auto new_grdim = + std::max((waves * device_multiprocessor_count) / gidim, (int64_t)1); if ( - // If difference is less than 25% of the original gdimy - (new_gdimy - gdimy) * 4 < gdimy && + // If difference is less than 25% of the original grdim + (new_grdim - grdim) * 4 < grdim && // and difference is less than 25% of the original number of blocks - ((new_gdimy * gdimx) - (gdimy * gdimx)) * 4 < gdimy * gdimx) { - gdimy = new_gdimy; + ((new_grdim * gidim) - (grdim * gidim)) * 4 < grdim * gidim) { + grdim = new_grdim; } } // Cannot unroll with cross grid reductions - if (gdimy > 1 && iter_unroll_factor > 1) { + if (grdim > 1 && iter_unroll_factor > 1) { // Readjust the thread bindings, ideally we would repeat the block setup // without considering iter domain unrolling, but for now will simplify bdimx = std::min(max_threads_in_block, bdimx * iter_unroll_factor); @@ -664,10 +668,18 @@ ReductionParams OuterReductionHeuristic( iter_unroll_factor = 1; } + int64_t gdimx = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimy = LaunchParams::UNINITIALIZED_VAL; + ReductionParams rparams; // cross grid implies cross block - rparams.cross_block_inner_reduce = bdimy > 1 || gdimy > 1; - rparams.cross_grid_inner_reduce = gdimy > 1; + rparams.cross_block_inner_reduction = bdimy > 1 || grdim > 1; + rparams.cross_grid_inner_reduction = grdim > 1; + if (rparams.cross_grid_inner_reduction) { + rparams.split_grid_dim_inner_reduction = true; + rparams.grid_dim_inner_reduction = ParallelType::BIDy; + gdimy = std::min(grdim, scheduler_utils::y_grid_limit); + } rparams.multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1; if (rparams.multiple_reds_per_blk) { @@ -675,15 +687,12 @@ ReductionParams OuterReductionHeuristic( } rparams.grid_dim_iter_dom = ParallelType::BIDx; - rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit; - - if (rparams.cross_grid_inner_reduce) { - rparams.grid_dim_inner_reduction = ParallelType::BIDy; - rparams.split_grid_dim_inner_reduction = - gdimy > scheduler_utils::y_grid_limit; + if (gidim > scheduler_utils::x_grid_limit) { + rparams.split_grid_dim_iter_dom = true; + gdimx = scheduler_utils::x_grid_limit; } - if (rparams.cross_block_inner_reduce) { + if (rparams.cross_block_inner_reduction) { if (rparams.block_dim_iter_dom == ParallelType::TIDx) { rparams.block_dim_inner_reduction = ParallelType::TIDy; } else { @@ -702,7 +711,7 @@ ReductionParams OuterReductionHeuristic( } rparams.lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, + gdimx, gdimy, LaunchParams::UNINITIALIZED_VAL, rparams.multiple_reds_per_blk ? bdimx : bdimy, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h index aafae3f09ff..a710e0c0ed8 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h @@ -31,9 +31,9 @@ class ReductionParams { // Inner Reduction Domain: // Reduce across the block? - bool cross_block_inner_reduce = false; + bool cross_block_inner_reduction = false; // Reduce across the grid? - bool cross_grid_inner_reduce = false; + bool cross_grid_inner_reduction = false; // Inner reduction unroll/vectorize bool unroll_inner_reduction = false; // Unrolling factor @@ -81,9 +81,9 @@ class ReductionParams { // Outer Reduction Domain if 3D Scheduled: // Reduce across the block? - bool cross_block_outer_reduce = false; + bool cross_block_outer_reduction = false; // Reduce across the grid? - bool cross_grid_outer_reduce = false; + bool cross_grid_outer_reduction = false; // Split grid dim for iteration axis in case it's too large for cuda bool split_grid_dim_outer_reduction = false; // Register persistent buffer size in outer dimension @@ -113,8 +113,8 @@ class ReductionParams { other.persistent_kernel == persistent_kernel && other.project_persistent_buffers == project_persistent_buffers && other.schedule_3D == schedule_3D && - other.cross_block_inner_reduce == cross_block_inner_reduce && - other.cross_grid_inner_reduce == cross_grid_inner_reduce && + other.cross_block_inner_reduction == cross_block_inner_reduction && + other.cross_grid_inner_reduction == cross_grid_inner_reduction && other.unroll_inner_reduction == unroll_inner_reduction && other.unroll_factor_inner_reduction == unroll_factor_inner_reduction && other.vectorize_inner_reduction == vectorize_inner_reduction && @@ -128,8 +128,8 @@ class ReductionParams { other.unroll_factor_iter_dom == unroll_factor_iter_dom && other.vectorize_iter_dom == vectorize_iter_dom && other.split_grid_dim_iter_dom == split_grid_dim_iter_dom && - other.cross_block_outer_reduce == cross_block_outer_reduce && - other.cross_grid_outer_reduce == cross_grid_outer_reduce && + other.cross_block_outer_reduction == cross_block_outer_reduction && + other.cross_grid_outer_reduction == cross_grid_outer_reduction && other.unroll_outer_reduction == unroll_outer_reduction && other.unroll_factor_outer_reduction == unroll_factor_outer_reduction && other.split_grid_dim_outer_reduction == @@ -153,10 +153,10 @@ class ReductionParams { if (schedule_3D) { ss << "3D Schedule\n" << "Outer Reduction: "; - if (cross_block_outer_reduce) { + if (cross_block_outer_reduction) { ss << "cross block - " << block_dim_outer_reduction << " / "; } - if (cross_grid_outer_reduce) { + if (cross_grid_outer_reduction) { ss << "cross grid - " << grid_dim_outer_reduction << " / "; ss << (split_grid_dim_outer_reduction ? "split grid dim / " : ""); } @@ -189,18 +189,18 @@ class ReductionParams { ss << "\nInner Reduction Domain: "; - if (cross_block_inner_reduce) { + if (cross_block_inner_reduction) { ss << "cross block - " << block_dim_inner_reduction << " / "; ss << (pad_inner_reduction_to_warp ? " pad to warp / " : ""); } - if (cross_grid_inner_reduce) { + if (cross_grid_inner_reduction) { ss << "cross grid - " << grid_dim_inner_reduction << " / "; ss << (split_grid_dim_inner_reduction ? "split grid dim / " : ""); } if (batches_per_block_inner_reduction > 1 || persistent_kernel) { ss << "persistent batch - " << batches_per_block_inner_reduction << " / "; } - ss << (cross_grid_inner_reduce && split_grid_dim_inner_reduction + ss << (cross_grid_inner_reduction && split_grid_dim_inner_reduction ? "split grid dimension / " : "") << (vectorize_inner_reduction ? "vectorize / " : "") @@ -225,8 +225,8 @@ class ReductionParamsHash { static_cast(rp.persistent_kernel) << (bits - 2) ^ static_cast(rp.project_persistent_buffers) << (bits - 3) ^ static_cast(rp.schedule_3D) << (bits - 4) ^ - static_cast(rp.cross_block_inner_reduce) << (bits - 5) ^ - static_cast(rp.cross_grid_inner_reduce) << (bits - 6) ^ + static_cast(rp.cross_block_inner_reduction) << (bits - 5) ^ + static_cast(rp.cross_grid_inner_reduction) << (bits - 6) ^ static_cast(rp.unroll_inner_reduction) << (bits - 7) ^ static_cast(rp.unroll_factor_inner_reduction) ^ static_cast(rp.vectorize_inner_reduction) << (bits - 8) ^ @@ -239,8 +239,8 @@ class ReductionParamsHash { static_cast(rp.unroll_factor_iter_dom) ^ static_cast(rp.vectorize_iter_dom) << (bits - 14) ^ static_cast(rp.split_grid_dim_iter_dom) << (bits - 15) ^ - static_cast(rp.cross_block_outer_reduce) << (bits - 16) ^ - static_cast(rp.cross_grid_outer_reduce) << (bits - 17) ^ + static_cast(rp.cross_block_outer_reduction) << (bits - 16) ^ + static_cast(rp.cross_grid_outer_reduction) << (bits - 17) ^ static_cast(rp.split_grid_dim_outer_reduction) << (bits - 18) ^ static_cast(rp.batches_per_block_outer_reduction) << (bits - 19); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp index 3850fa9638b..57988d8d994 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp @@ -43,257 +43,170 @@ TensorView* scheduleReductionTV( !(!rparams.fastest_dim && rparams.vectorize_inner_reduction), "Cannot vectorize reduction domain on outer reductions."); - TORCH_INTERNAL_ASSERT( - !(rparams.cross_grid_inner_reduce && rparams.persistent_kernel), - "Grid reductions not implemented yet for persistent kernels."); - TORCH_INTERNAL_ASSERT( !(rparams.multiple_reds_per_blk && !has_iter_axis), "Multiple reductions requires an iter domain, but one wasn't found."); TORCH_INTERNAL_ASSERT( - !(rparams.cross_grid_inner_reduce && rparams.unroll_iter_dom), + !(rparams.cross_grid_inner_reduction && rparams.unroll_iter_dom), "Unrolling on iter domain not supported with cross grid reductions."); TORCH_INTERNAL_ASSERT( !(rparams.unroll_iter_dom && !has_iter_axis), "Unrolling on iter domain requires an iter domain."); - // Inner reduction axis: - if (rparams.unroll_inner_reduction) { - if (rparams.persistent_kernel) { - if (rparams.vectorize_inner_reduction) { - reduction_tv->split( - inner_reduce_axis, - rparams.batches_per_block_inner_reduction, - false); - reduction_tv->split( - inner_reduce_axis + 1, rparams.unroll_factor_inner_reduction); - - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(ParallelType::Vectorize); - } else { - reduction_tv->split( - inner_reduce_axis, - rparams.batches_per_block_inner_reduction * - rparams.unroll_factor_inner_reduction, - false); - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 2)->padToMultipleOfWarp(); - } - } - } else { - if (isParallelTypeThread(rparams.block_dim_inner_reduction)) { - if (rparams.vectorize_inner_reduction) { - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_inner_reduction)); - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(ParallelType::Vectorize); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } - } else { - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_inner_reduction)); - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(inner_reduce_axis + 2) - ->parallelize(rparams.block_dim_inner_reduction); - - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 2)->padToMultipleOfWarp(); - } - } - } else { - // Inner reduction is not parallelized, but is unrolled or vectorized: - reduction_tv->split( - inner_reduce_axis, rparams.unroll_factor_inner_reduction); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize( - rparams.vectorize_inner_reduction ? ParallelType::Vectorize - : ParallelType::Unroll); - } + auto vectorize = [&reduction_tv](int axis, int factor) { + reduction_tv->split(axis, factor); + reduction_tv->axis(axis + 1)->parallelize(ParallelType::Vectorize); + }; + + auto inner_parallel = [&reduction_tv](int axis, ParallelType ptype) { + reduction_tv->split(axis, NamedScalar::getParallelDim(ptype)); + reduction_tv->axis(axis + 1)->parallelize(ptype); + }; + + auto inner_unswitch = [&reduction_tv](int axis) { + reduction_tv->split(axis, 1); + reduction_tv->axis(axis + 1)->parallelize(ParallelType::Unswitch); + }; + + auto inner_unroll = [&reduction_tv](int axis, int factor) { + reduction_tv->split(axis, factor); + reduction_tv->axis(axis + 1)->parallelize(ParallelType::Unroll); + }; + + auto outer_parallel = [&reduction_tv](int axis, ParallelType ptype) { + reduction_tv->split(axis, NamedScalar::getParallelDim(ptype), false); + reduction_tv->axis(axis)->parallelize(ptype); + }; + + auto outer_unswitch = [&reduction_tv](int axis) { + reduction_tv->split(axis, 1, false); + reduction_tv->axis(axis)->parallelize(ParallelType::Unswitch); + }; + + auto outer_unroll = [&reduction_tv](int axis, int factor) { + reduction_tv->split(axis, factor, false); + reduction_tv->axis(axis)->parallelize(ParallelType::Unroll); + }; + + if (rparams.persistent_kernel) { + // Persistent Format: + // [Grid Split, persistent buffer, unswitch, unroll, thread dim, vectorize] + if (rparams.vectorize_inner_reduction) { + vectorize(inner_reduce_axis, rparams.unroll_factor_inner_reduction); + } + auto outer_i = inner_reduce_axis; + if (rparams.cross_grid_inner_reduction) { + outer_parallel(outer_i++, rparams.grid_dim_inner_reduction); + } + + reduction_tv->split( + outer_i++, rparams.batches_per_block_inner_reduction, false); + + outer_unswitch(outer_i++); + + if (!rparams.vectorize_inner_reduction && rparams.unroll_inner_reduction) { + outer_unroll(outer_i++, rparams.unroll_factor_inner_reduction); + } + + reduction_tv->axis(outer_i)->parallelize(rparams.block_dim_inner_reduction); + + if (rparams.pad_inner_reduction_to_warp) { + reduction_tv->axis(outer_i)->padToMultipleOfWarp(); } - // Unswitch axis which gives us finer control on allocations with - // unrolling - reduction_tv->split(inner_reduce_axis, 1); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unswitch); } else { - // Parallelize reduction axis, don't unroll it0 - if (rparams.cross_block_inner_reduce) { - if (rparams.persistent_kernel) { - reduction_tv->split( - inner_reduce_axis, - rparams.batches_per_block_inner_reduction, - false); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } - } else { - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_inner_reduction)); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(rparams.block_dim_inner_reduction); - if (rparams.pad_inner_reduction_to_warp) { - reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); - } + // Non-persistent format: + // [Grid Split, Remainder, unswitch, unroll, thread dim, vectorize] + if (rparams.vectorize_inner_reduction) { + vectorize(inner_reduce_axis, rparams.unroll_factor_inner_reduction); + } + + if (rparams.cross_block_inner_reduction) { + inner_parallel(inner_reduce_axis, rparams.block_dim_inner_reduction); + if (rparams.pad_inner_reduction_to_warp) { + reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); } - } else { - // No parallelization on reduction dim, fake an unswitch axis for - // rfactor - reduction_tv->split(inner_reduce_axis, 1); - reduction_tv->axis(inner_reduce_axis + 1) - ->parallelize(ParallelType::Unswitch); } - } - if (rparams.cross_grid_inner_reduce) { - reduction_tv->split( - inner_reduce_axis, - NamedScalar::getParallelDim(rparams.grid_dim_inner_reduction), - false); - reduction_tv->axis(inner_reduce_axis) - ->parallelize(rparams.grid_dim_inner_reduction); + if (!rparams.vectorize_inner_reduction && rparams.unroll_inner_reduction) { + inner_unroll(inner_reduce_axis, rparams.unroll_factor_inner_reduction); + } + + inner_unswitch(inner_reduce_axis); + if (rparams.cross_grid_inner_reduction) { + if (rparams.split_grid_dim_inner_reduction) { + outer_parallel(inner_reduce_axis, rparams.grid_dim_inner_reduction); + } else { + reduction_tv->axis(inner_reduce_axis) + ->parallelize(rparams.grid_dim_inner_reduction); + } + } } // Outer reduction axis if (rparams.schedule_3D) { - if (rparams.unroll_outer_reduction) { - if (rparams.persistent_kernel) { - reduction_tv->split( - outer_reduce_axis, - rparams.batches_per_block_outer_reduction * - rparams.unroll_factor_outer_reduction, - false); - reduction_tv->split( - outer_reduce_axis, rparams.unroll_factor_outer_reduction); - - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(outer_reduce_axis + 2) - ->parallelize(rparams.block_dim_outer_reduction); - } else { - if (isParallelTypeThread(rparams.block_dim_outer_reduction)) { - reduction_tv->split( - outer_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_outer_reduction)); - reduction_tv->split( - outer_reduce_axis, rparams.unroll_factor_outer_reduction); - - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - reduction_tv->axis(outer_reduce_axis + 2) - ->parallelize(rparams.block_dim_outer_reduction); + if (rparams.persistent_kernel) { + // Persistent Format: + // [Grid Split, persistent buffer, unroll, thread dim] + auto outer_i = outer_reduce_axis; + if (rparams.cross_grid_outer_reduction) { + outer_parallel(outer_i++, rparams.grid_dim_outer_reduction); + } - } else { - // outer reduction is not parallelized, but is unrolled or vectorized: - reduction_tv->split( - outer_reduce_axis, rparams.unroll_factor_outer_reduction); - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(ParallelType::Unroll); - } + reduction_tv->split( + outer_i++, rparams.batches_per_block_outer_reduction, false); + + if (rparams.unroll_outer_reduction) { + outer_unroll(outer_i++, rparams.unroll_factor_outer_reduction); } + + reduction_tv->axis(outer_i)->parallelize( + rparams.block_dim_outer_reduction); } else { - // Parallelize reduction axis, don't unroll it0 - if (rparams.cross_block_outer_reduce) { - if (rparams.persistent_kernel) { - reduction_tv->split( - outer_reduce_axis, - rparams.batches_per_block_outer_reduction, - false); - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(rparams.block_dim_outer_reduction); - } else { - reduction_tv->split( - outer_reduce_axis, - NamedScalar::getParallelDim(rparams.block_dim_outer_reduction)); - reduction_tv->axis(outer_reduce_axis + 1) - ->parallelize(rparams.block_dim_outer_reduction); - } + // Non-persistent format: + // [Grid Split, Remainder, unroll, thread dim] + if (rparams.cross_block_outer_reduction) { + inner_parallel(outer_reduce_axis, rparams.block_dim_outer_reduction); } - } - if (rparams.cross_grid_outer_reduce) { - reduction_tv->split( - outer_reduce_axis, - NamedScalar::getParallelDim(rparams.grid_dim_outer_reduction), - false); - reduction_tv->axis(outer_reduce_axis) - ->parallelize(rparams.grid_dim_outer_reduction); + if (rparams.unroll_outer_reduction) { + inner_unroll(outer_reduce_axis, rparams.unroll_factor_outer_reduction); + } + + if (rparams.cross_grid_outer_reduction) { + outer_parallel(outer_reduce_axis, rparams.grid_dim_outer_reduction); + } } } // Iteration domain if (has_iter_axis) { + // [Grid Split, unswitch, unroll, thread dim, vectorize] + + if (rparams.vectorize_iter_dom) { + vectorize(iter_axis, rparams.unroll_factor_iter_dom); + } + if (isParallelTypeThread(rparams.block_dim_iter_dom)) { - if (rparams.vectorize_iter_dom) { - reduction_tv->split(iter_axis, rparams.unroll_factor_iter_dom); - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Vectorize); - - reduction_tv->split( - iter_axis, NamedScalar::getParallelDim(rparams.block_dim_iter_dom)); - reduction_tv->axis(iter_axis + 1) - ->parallelize(rparams.block_dim_iter_dom); - } else { - if ((rparams.fastest_dim && rparams.multiple_reds_per_blk) || - !rparams.fastest_dim) { - reduction_tv->split( - iter_axis, - NamedScalar::getParallelDim(rparams.block_dim_iter_dom)); - reduction_tv->axis(iter_axis + 1) - ->parallelize(rparams.block_dim_iter_dom); - } - if (rparams.unroll_iter_dom) { - reduction_tv->split(iter_axis, rparams.unroll_factor_iter_dom); - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Unroll); - } - } - } else if (rparams.unroll_iter_dom) { - // Iteration domain is not parallelized but it is unrolled or vectorized - reduction_tv->split(iter_axis, rparams.unroll_factor_iter_dom); - if (rparams.vectorize_iter_dom) { - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Vectorize); - } else { - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Unroll); - } + inner_parallel(iter_axis, rparams.block_dim_iter_dom); } + + if (!rparams.vectorize_iter_dom && rparams.unroll_iter_dom) { + inner_unroll(iter_axis, rparams.unroll_factor_iter_dom); + } + if (rparams.unroll_iter_dom) { - reduction_tv->split(iter_axis, 1); - reduction_tv->axis(iter_axis + 1)->parallelize(ParallelType::Unswitch); + inner_unswitch(iter_axis); } - if (rparams.fastest_dim && rparams.split_grid_dim_iter_dom) { - reduction_tv->split(iter_axis, scheduler_utils::x_grid_limit); - reduction_tv->axis(iter_axis + 1)->parallelize(rparams.grid_dim_iter_dom); - } else { - reduction_tv->axis(iter_axis)->parallelize(rparams.grid_dim_iter_dom); + if (isParallelTypeThread(rparams.grid_dim_iter_dom)) { + if (rparams.split_grid_dim_iter_dom) { + outer_parallel(iter_axis, rparams.grid_dim_iter_dom); + } else { + reduction_tv->axis(iter_axis)->parallelize(rparams.grid_dim_iter_dom); + } } } @@ -563,6 +476,48 @@ void multiReductionInliner( scheduler_utils::computeWithOutputs( red_tv, pos, ComputeAtMode::BestEffort); } + // For topologies where there may not be paths to all inputs/outputs from + // the reductions, we need to take a similar approach to the unrolled + // version and setup of compute at from inputs->outputs that are not + // inputs/outputs of the reductions. + std::vector compute_to; + std::unordered_set outs_of_reds; + { + auto outs_of_red_vec = ir_utils::outputTvsOf(ref_tvs); + outs_of_reds = std::unordered_set( + outs_of_red_vec.begin(), outs_of_red_vec.end()); + } + for (auto out : ir_utils::filterByType(fusion->outputs())) { + // only terminating outputs + if (out->uses().size()) { + continue; + } + if (outs_of_reds.find(out) != outs_of_reds.end()) { + continue; + } + compute_to.push_back(out); + } + + std::vector compute_from; + std::unordered_set inps_of_reds; + { + auto inps_of_red_vec = ir_utils::inputTvsOf(ref_tvs); + inps_of_reds = std::unordered_set( + inps_of_red_vec.begin(), inps_of_red_vec.end()); + } + for (auto inp : ir_utils::filterByType(fusion->inputs())) { + if (inps_of_reds.find(inp) != inps_of_reds.end()) { + continue; + } + compute_from.push_back(inp); + } + + scheduler_utils::computeAtBetween( + compute_from, + compute_to, + -1, + ComputeAtMode::MostInlined, + mapped_to_trivial_reduction); } } @@ -595,12 +550,6 @@ int idPos(const IterDomain* id) { } inner_most--; - // Reduction and block - if (id->isReduction() && id->isBlockDim()) { - return inner_most; - } - inner_most--; - // Reduction and constant if (id->isReduction() && id->extent()->isConstScalar()) { return inner_most; @@ -614,7 +563,7 @@ int idPos(const IterDomain* id) { inner_most--; // Reduction and thread - if (id->isReduction() && id->isThreadDim()) { + if (id->isReduction() && id->isThread()) { return inner_most; } inner_most--; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 46b574ac6af..4f2982b01f2 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -1,10 +1,12 @@ #include +#include #include #include #include #include #include #include +#include #include #include @@ -38,7 +40,8 @@ class SchedulerTopologyChecker { auto all_vals = fusion->usedMathVals(); std::vector reduction_tvs; for (auto tv : ir_utils::filterByType(all_vals)) { - if (tv->hasReduction() && !fusion->hasInput(tv)) { + if (tv->hasReduction() && + !(fusion == tv->fusion() && tv->isFusionInput())) { reduction_tvs.push_back(tv); } } @@ -355,6 +358,50 @@ class SchedulerTopologyChecker { return true; } }; + +bool isConnectedFusionGraph(Fusion* fusion) { + if (fusion->outputs().empty()) { + // Trivial case interpreted as connected + return true; + } + + // A set of connected components on the fusion graph + DisjointSet component_sets; + + // Iterate through all used exprs + for (auto expr : fusion->exprs()) { + TORCH_INTERNAL_ASSERT( + !expr->inputs().empty(), "unknown expr with zero input"); + + // Each expr joins all its inputs and + // outputs to the same component + auto input0 = expr->inputs()[0]; + for (auto input : expr->inputs()) { + component_sets.join(input0, input); + } + for (auto output : expr->outputs()) { + component_sets.join(input0, output); + } + } + + // Join aliased outputs + for (auto alias_it : fusion->ioAlias()) { + component_sets.join(alias_it.first, alias_it.second); + } + + // Check connected-ness: + // If there is no independent compute flow + // on this fusion graph, all outputs will be + // equivalent/connected to the first output. + auto output0 = fusion->outputs()[0]; + for (auto output : fusion->outputs()) { + if (!component_sets.areEquivalent(output0, output)) { + return false; + } + } + return true; +} + } // namespace SchedulerRuntimeInfo::SchedulerRuntimeInfo( @@ -634,39 +681,10 @@ bool SchedulerEntry::sameAs(const SchedulerEntry* other) { } namespace { -template -inline bool isTrivialReduction(REDUCTION_OP* red) { - auto o_tv = red->out()->template as(); - // Assuming graph unscheduled at this point. - for (auto id : o_tv->getRootDomain()) { - if (id->isReduction() && !id->extent()->isOneInt()) { - return false; - } - } - return true; -} - -template -std::vector findReductionOps(Fusion* fusion) { - std::vector red_ops; - for (auto expr : fusion->exprs()) { - if (auto red = dynamic_cast(expr)) { - if (!isTrivialReduction(red)) { - red_ops.push_back(red); - } - } - } - return red_ops; -} - std::vector findTransposeOps(Fusion* fusion) { - std::vector transpose_ops; - for (auto expr : fusion->exprs()) { - if (auto transpose_op = dynamic_cast(expr)) { - transpose_ops.push_back(transpose_op); - } - } - return transpose_ops; + auto exprs = fusion->exprs(); + auto transpose_ops = ir_utils::filterByType(exprs); + return std::vector(transpose_ops.begin(), transpose_ops.end()); } static bool checkPatternEquivalence( @@ -765,9 +783,8 @@ class ReductionScheduler : public SchedulerEntry { } // Make sure reduction axes are consistent through the fusion - if (findReductionOps(fusion).size() + - findReductionOps(fusion).size() > - 1) { + auto reduction_ops = ir_utils::getReductionOps(fusion); + if (reduction_ops.size() > 1) { // Before examining the reduction axes want to quickly // check the reductions have the same axis width // to avoid building root domain map in easier cases @@ -857,9 +874,16 @@ class PointWiseScheduler : public SchedulerEntry { } static bool canScheduleCompileTime(Fusion* fusion) { - auto red_ops = findReductionOps(fusion); - auto welford_ops = findReductionOps(fusion); - return red_ops.empty() && welford_ops.empty(); + // Currently using the same path as the scheduler + // to eliminate mismatch between canSchedule and + // schedule pointwise. + if (!hasReferenceTensorView(fusion)) { + return false; + } + + auto reduction_ops = ir_utils::getReductionOps(fusion); + auto welford_ops = ir_utils::filterByType(reduction_ops); + return reduction_ops.empty() && welford_ops.empty(); } static bool canScheduleRunTime( @@ -900,6 +924,14 @@ class PersistentKernelScheduler : public SchedulerEntry { } static bool canScheduleCompileTime(Fusion* fusion) { + auto reduction_ops = ir_utils::getReductionOps(fusion); + auto welford_ops = ir_utils::filterByType(reduction_ops); + // For persistent schedule we want welford translated to average and + // standard deviation reductions. + if (welford_ops.begin() != welford_ops.end()) { + return false; + } + auto view_tvs = scheduler_utils::getViewTVs(fusion); if (view_tvs.size() > 0) { return false; @@ -1079,8 +1111,13 @@ bool checkCanSchedule( // since for all current use cases // it has to pass all the compile time checks to create a data cache for this // fusion. - if (!data_cache && !SchedulerType::canScheduleCompileTime(fusion)) { - return false; + if (!data_cache) { + if (!isConnectedFusionGraph(fusion)) { + return false; + } + if (!SchedulerType::canScheduleCompileTime(fusion)) { + return false; + } } return SchedulerType::canScheduleRunTime(fusion, runtime_info, data_cache); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 7ce9addf0cb..90b348236cf 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -287,6 +287,15 @@ class PersistentBufferResolution : public IterVisitor { } if (tv->hasReduction()) { + if (std::any_of( + resolution_points_.begin(), + resolution_points_.end(), + [&tv](TensorView* resolution_point) { + return DependencyCheck::isDependencyOf(resolution_point, tv); + })) { + // If already resolved, don't start a new reduction path. + return; + } on_reduction_path_.emplace(tv); } } @@ -587,7 +596,7 @@ void computeAtBetween( return mapped_to_trivial_reduction.count(id); }); - pos = pos_it == consumer->domain()->domain().end() + auto consumer_pos = pos_it == consumer->domain()->domain().end() ? pos : std::min( (int)std::distance( @@ -596,7 +605,7 @@ void computeAtBetween( (pos < 0 ? pos + (int)consumer->nDims() : pos)); // Assume we don't want to reset computeAt on tensors that have already // performed it. - producer->computeAt(consumer, pos, mode); + producer->computeAt(consumer, consumer_pos, mode); } } } @@ -1038,15 +1047,22 @@ std::vector> cacheAndForkOutputs( } namespace { +// If this is an rfactored reduction domain, actually check the root domain, +// this is because the rfactored reduction tensorview has the vectorized +// dimension, but that means the rfactor domain could have reordered what we +// consider the "inner most" allocated position on it if we consider the rfactor +// dimension. IterDomain* innerMostRootDim(TensorView* tv) { if (tv->nDims() == 0) { return nullptr; } IterDomain* inner_most_id = nullptr; - for (auto it = tv->getMaybeRFactorDomain().rbegin(); - it != tv->getMaybeRFactorDomain().rend(); - it++) { + auto root_domain = tv->hasReduction() && tv->hasRFactor() + ? tv->getRootDomain() + : tv->getMaybeRFactorDomain(); + + for (auto it = root_domain.rbegin(); it != root_domain.rend(); it++) { if ((*it)->isReduction() && tv->isFusionInput()) { continue; } @@ -1084,7 +1100,7 @@ IterDomain* projectIdToRoot( return reference_id; } - auto replay_exprs = ExprSort::getExprs(tv->fusion(), {reference_id}); + auto replay_exprs = StmtSort::getExprs(tv->fusion(), {reference_id}, false); if (replay_exprs.empty()) { return reference_id; } @@ -1193,12 +1209,16 @@ std::unordered_set FindAllMappedDims::from( TensorView* tv, IterDomain* id, bool vectorize_pass) { + auto root_domain = tv->hasReduction() && tv->hasRFactor() + ? tv->getRootDomain() + : tv->getMaybeRFactorDomain(); + TORCH_INTERNAL_ASSERT( std::find_if( - tv->getMaybeRFactorDomain().begin(), - tv->getMaybeRFactorDomain().end(), + root_domain.begin(), + root_domain.end(), [&id](IterDomain* root_id) { return root_id == id; }) != - tv->getMaybeRFactorDomain().end(), + root_domain.end(), "Tried to map out ", id, " from TV ", diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 2bf8967f74e..911bda3da04 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -3,10 +3,13 @@ #include #include #include +#include #include #include #include #include +#include +#include // Cleanup #include @@ -24,8 +27,14 @@ DataType aten_opt_type_map(const c10::optional& scalar_type) { } } // namespace -TensorView::TensorView(TensorDomain* domain, DataType dtype, MemoryType mtype) - : Val(ValType::TensorView, dtype), domain_(domain), memory_type_(mtype) { +TensorView::TensorView( + IrBuilderPasskey passkey, + TensorDomain* domain, + DataType dtype, + MemoryType mtype) + : Val(passkey, ValType::TensorView, dtype), + domain_(domain), + memory_type_(mtype) { // Don't do this after transforms if (domain_->domain() == domain_->getRootDomain()) { // Mark the size-1 axes as broadcast to support implicit broadcast semantic @@ -38,10 +47,15 @@ TensorView::TensorView(TensorDomain* domain, DataType dtype, MemoryType mtype) } } -TensorView::TensorView(const std::shared_ptr& tensor_type) - : Val(ValType::TensorView, - aten_opt_type_map(tensor_type->scalarType()), - false) { +TensorView::TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& tensor_type) + : Val(passkey, + ValType::TensorView, + aten_opt_type_map(tensor_type->scalarType())) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); std::vector sizes; TORCH_CHECK( @@ -51,13 +65,14 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) if (tensor_type->sizes()[i].has_value() && tensor_type->sizes()[i].value() == 1) { // If size is known to be 1, assuem it needs to be broadcasted. - sizes.push_back(new IterDomain( - new Int(0), - new Int(1), + sizes.push_back(IrBuilder::create( + passkey.ir_container_->zeroVal(), + passkey.ir_container_->oneVal(), ParallelType::Serial, IterType::BroadcastWithStride)); } else { - sizes.push_back(new IterDomain(new Int(0), new Int())); + sizes.push_back(IrBuilder::create( + passkey.ir_container_->zeroVal(), IrBuilder::create())); } } @@ -92,8 +107,16 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) } } - domain_ = new TensorDomain(sizes, contig_info); - name_ = fusion_->registerVal(this); + domain_ = IrBuilder::create(sizes, contig_info); +} + +TensorView::TensorView( + IrBuilderPasskey passkey, + const std::shared_ptr& jit_value) + : TensorView(passkey, jit_value->type()->cast()) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); } TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) @@ -102,7 +125,9 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) compute_at_pos_(src->compute_at_pos_), max_producer_pos_(src->max_producer_pos_), memory_type_(src->memory_type_), - swizzle_type_(src->swizzle_type_) { + swizzle_type_(src->swizzle_type_), + is_double_buffered_(src->is_double_buffered_), + cpu_scalar_(src->cpu_scalar_) { for (const auto id : src->axesToSwizzle()) { axes_to_swizzle_.push_back(ir_cloner->clone(id)); } @@ -152,6 +177,18 @@ std::vector::size_type TensorView::nDims() const { return domain()->nDims(); } +// sets cpu_scalar_ value, which is special handling for CPU based zero-dim +// tensors (i.e. CPU Tensors that only have one value). This is only used if +// on an input value, otherwise ignored. This is important as special handling +// because these "scalars" should be type promoted as a tensor, but we want to +// avoid explicit copying of the data, so we want to pass the data value as a +// standard kernel argument value. +void TensorView::setCpuScalar(bool is_cpu_scalar) { + TORCH_INTERNAL_ASSERT( + nDims() == 0, "Only 0-dim tensors can be marked as a cpu scalar."); + cpu_scalar_ = is_cpu_scalar; +} + IterDomain* TensorView::axis(int pos) const { TORCH_INTERNAL_ASSERT( nDims() > 0, "Tried to access an axis in a 0-dim TensorView"); @@ -167,6 +204,9 @@ IterDomain* TensorView::axis(int pos) const { } void TensorView::setComputeAt(unsigned int pos, bool decrease) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); if (pos <= compute_at_pos_ && !decrease) { return; } @@ -182,6 +222,9 @@ void TensorView::setComputeAt(unsigned int pos, bool decrease) { } void TensorView::setMaxProducer(unsigned int pos, bool decrease) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); if (pos <= max_producer_pos_ && !decrease) { return; } @@ -200,6 +243,9 @@ TensorView* TensorView::computeAt( TensorView* consumer, int position, ComputeAtMode mode) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // Make sure this and consumer are not the same tensor, that's illegal TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)"); @@ -228,6 +274,9 @@ TensorView* TensorView::computeWith( TensorView* consumer, int position, ComputeAtMode mode) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // Make sure this and consumer are not the same tensor, that's illegal TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)"); @@ -290,7 +339,7 @@ TensorView* TensorView::split( unsigned int factor, bool inner_split, bool trim_out_of_bounds) { - split(axis, new Int(factor), inner_split, trim_out_of_bounds); + split(axis, IrBuilder::create(factor), inner_split, trim_out_of_bounds); return this; } @@ -336,6 +385,9 @@ TensorView* TensorView::merge(int axis_o, int axis_i) { } TensorView* TensorView::reorder(const std::unordered_map& old2new_) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); TORCH_INTERNAL_ASSERT( !(nDims() == 0 && old2new_.size() > 0), "Tried to reorder a 0-dim TensorView"); @@ -383,6 +435,9 @@ TensorView* TensorView::reorder(const std::unordered_map& old2new_) { TensorView* TensorView::swizzle( SwizzleType type, const std::vector& axes) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); swizzle_type_ = type; // Clear previously set swizzle axes if any @@ -432,6 +487,9 @@ TensorView* TensorView::swizzle( } TensorView* TensorView::rFactor(const std::vector& axes) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // TODO: I think we should do this but // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the // moment. @@ -462,7 +520,8 @@ TensorView* TensorView::rFactor(const std::vector& axes) { auto consumer_domain = domain_pair.second; // This domain will be the consumer, so create the producer - TensorView* producer = new TensorView(producer_domain, getDataType().value()); + TensorView* producer = + IrBuilder::create(producer_domain, getDataType().value()); // Set domain of consumer setDomain(consumer_domain); @@ -470,14 +529,14 @@ TensorView* TensorView::rFactor(const std::vector& axes) { // Setup dependency chain, inserting producer before this op. // Expr* producer_definition = - new ReductionOp( + IrBuilder::create( this_definition->getReductionOpType(), this_definition->init(), producer, this_definition->in()); // Expr* consumer_definition = - new ReductionOp( + IrBuilder::create( this_definition->getReductionOpType(), this_definition->init(), consumer, @@ -489,6 +548,9 @@ TensorView* TensorView::rFactor(const std::vector& axes) { TensorView* TensorView::welfordRfactorHelper( TensorView* tv, const std::vector& axes) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); // Hack: // Semantically we should always keep the outputs of welfordOp scheduled // the same but the user end cannot guarantee that. @@ -520,7 +582,8 @@ TensorView* TensorView::welfordRfactorHelper( std::vector new_contig( tv->domain()->contiguity().begin(), tv->domain()->contiguity().end()); // replace tensor domain of target tv - tv->setDomain(new TensorDomain(tv->getRootDomain(), new_id, new_contig)); + tv->setDomain(IrBuilder::create( + tv->getRootDomain(), new_id, new_contig)); } // Split tensor view into 2 parts @@ -532,7 +595,7 @@ TensorView* TensorView::welfordRfactorHelper( // This domain will be the consumer, so create the producer TensorView* producer = - new TensorView(producer_domain, tv->getDataType().value()); + IrBuilder::create(producer_domain, tv->getDataType().value()); // Set domain of consumer tv->setDomain(consumer_domain); @@ -545,6 +608,9 @@ WelfordResult TensorView::rFactor( TensorView* avg, TensorView* var, TensorView* n) { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView"); FusionGuard fg(fusion()); TORCH_CHECK( @@ -588,7 +654,7 @@ WelfordResult TensorView::rFactor( // Setup dependency chain, inserting producer before this op. // Expr* producer_definition = - new WelfordOp( + IrBuilder::create( producer_avg, producer_var, producer_n, /*out var/avg/count */ @@ -600,7 +666,7 @@ WelfordResult TensorView::rFactor( wop->inN()); // Expr* consumer_definition = - new WelfordOp( + IrBuilder::create( avg, var, n, @@ -615,6 +681,9 @@ WelfordResult TensorView::rFactor( } TensorView* TensorView::cache_before() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); FusionGuard fg(fusion()); TORCH_CHECK( @@ -652,8 +721,10 @@ TensorView* TensorView::cache_before() { // This domain will be the consumer which needs a new domain, so replace the // producers domain with this domain. - TensorView* producer = new TensorView( - new TensorDomain( + TensorView* producer = IrBuilder::create( + container(), + IrBuilder::create( + container(), domain()->getRootDomain(), domain()->getRFactorDomain(), domain()->domain(), @@ -671,8 +742,10 @@ TensorView* TensorView::cache_before() { new_root_domain[i++] = dom->clone(); } - consumer->setDomain(new TensorDomain( - new_root_domain, std::vector(new_root_domain.size(), true))); + consumer->setDomain(IrBuilder::create( + container(), + new_root_domain, + std::vector(new_root_domain.size(), true))); // Insert producer - Cache_Before (CB) - before this TV. // Before: Prev TV -> [Definition Op] -> This TV @@ -684,7 +757,7 @@ TensorView* TensorView::cache_before() { ir_utils::replaceValInExpr(definition(), this, producer); // Expr* producer_uses = - new UnaryOp(UnaryOpType::Set, consumer, producer); + IrBuilder::create(container(), UnaryOpType::Set, consumer, producer); // definition_ is no longer valid // setDefinition(nullptr); @@ -697,6 +770,9 @@ TensorView* TensorView::cache_before() { } TensorView* TensorView::cache_fork() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); FusionGuard fg(fusion()); // Before: [Expr] -> This TV (Global Output) -> [Usage Expr] @@ -704,7 +780,7 @@ TensorView* TensorView::cache_fork() { // (Fork) -> [Set Expr] -> New TV (Global Output) TORCH_CHECK( - fusion()->hasOutput(this) && !this->uses().empty(), + this->isFusionOutput() && !this->uses().empty(), "Error adding cache_fork ", this, " this TensorView must be an output with subsequent uses"); @@ -717,14 +793,16 @@ TensorView* TensorView::cache_fork() { // This domain will be the producer, so create the consumer auto root_domain = TensorDomain::noReductions(getMaybeRFactorDomain()); - TensorView* new_output = new TensorView( - new TensorDomain( + TensorView* new_output = IrBuilder::create( + container(), + IrBuilder::create( + container(), IterDomain::clone(root_domain), std::vector(root_domain.size(), true)), getDataType().value()); // Create write operation from this TV to new output - new UnaryOp(UnaryOpType::Set, new_output, this); + IrBuilder::create(container(), UnaryOpType::Set, new_output, this); // The new TV becomes an output. // New TV has global memory type. @@ -739,13 +817,14 @@ TensorView* TensorView::cache_fork() { } TensorView* TensorView::cache_after() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); FusionGuard fg(fusion()); - const bool kIsFusionInput = fusion()->hasInput(this); - // Get all the uses for this Tensorview TORCH_CHECK( - !fusion()->hasOutput(this), + !isFusionOutput(), "Error adding cache_after ", this, " we restrict using cache_after on an output."); @@ -759,7 +838,7 @@ TensorView* TensorView::cache_after() { // It also did additional transformation when this tensor is an // input and the outputs of its consumers have computeAt. Make sure // we no longer rely on that behavior. - if (kIsFusionInput) { + if (isFusionInput()) { for (const auto& expr : uses()) { for (TensorView* output : ir_utils::filterByType(expr->outputs())) { @@ -782,9 +861,12 @@ TensorView* TensorView::cache_after() { } // This domain will be the producer, so create the consumer - TensorView* consumer = new TensorView( - new TensorDomain( - new_root_domain, std::vector(new_root_domain.size(), true)), + TensorView* consumer = IrBuilder::create( + container(), + IrBuilder::create( + container(), + new_root_domain, + std::vector(new_root_domain.size(), true)), getDataType().value()); // Set domain of producer - No Change @@ -800,14 +882,14 @@ TensorView* TensorView::cache_after() { } // Expr* consumer_definition = - new UnaryOp(UnaryOpType::Set, consumer, producer); + IrBuilder::create(container(), UnaryOpType::Set, consumer, producer); return consumer; } void TensorView::setMemoryType(MemoryType mt) { memory_type_ = mt; - if (fusion()->hasInput(this) || fusion()->hasOutput(this)) { + if (isFusionInput() || isFusionOutput()) { TORCH_INTERNAL_ASSERT( mt == MemoryType::Global, "Tried to set an input or output to the fusion to a non-global memory type."); @@ -832,7 +914,23 @@ void TensorView::clearReductionIterDomains() { } } - setDomain(new TensorDomain(new_root, new_contig)); + setDomain(IrBuilder::create(container(), new_root, new_contig)); +} + +void TensorView::doubleBuffer() { + // Early correctness checking. May miss eventual errors as the + // checks depend on memory types and parallelization, which may not + // be finalized until lowering. + validateDoubleBufferedTensor(this); + is_double_buffered_ = true; +} + +bool TensorView::isEmptyTensor() const { + auto& root_domain = getMaybeRFactorDomain(); + return std::all_of( + root_domain.begin(), root_domain.end(), [](IterDomain* id) { + return id->extent()->isZeroInt(); + }); } TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) { @@ -872,7 +970,8 @@ TensorView* TensorViewBuilder::build() const { std::vector domain(ndims_, nullptr); for (const auto i : c10::irange(ndims_)) { if (shape_.empty() || shape_[i] == -1) { - domain[i] = new IterDomain(new Int(0), new Int()); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create()); } else { TORCH_CHECK( shape_[i] >= 0, @@ -880,19 +979,22 @@ TensorView* TensorViewBuilder::build() const { "For a tensor representing a single scalar use ndims = 0 with no sizes set."); if (shape_[i] == 1) { // If size is known to be 1, assume it needs to be broadcasted. - domain[i] = new IterDomain( - new Int(0), - new Int(1), + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), ParallelType::Serial, IterType::BroadcastWithStride); } else { - domain[i] = new IterDomain(new Int(0), new Int(shape_[i])); + domain[i] = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create(shape_[i])); } } } // Create the final TensorView - return new TensorView(new TensorDomain(domain, contiguity_), dtype_); + return IrBuilder::create( + IrBuilder::create(domain, contiguity_), dtype_); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 54136616268..bae77943b33 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -228,7 +228,7 @@ BestEffortReplay::BestEffortReplay( } // Grab expr history of iter domains in target_domain - std::vector target_exprs = ExprSort::getExprs( + std::vector target_exprs = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector(target_domain.begin(), target_domain.end())); @@ -239,7 +239,7 @@ BestEffortReplay::BestEffortReplay( // replay_domain map. // Map replay domain's IterDomains to the Exprs they're used in - std::vector replay_exprs = ExprSort::getExprs( + std::vector replay_exprs = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector(replay_domain.begin(), replay_domain.end())); @@ -561,7 +561,7 @@ struct ConsumerForwardingInfo { auto consumer_bcast_ids_not_in_producer = consumer_bcast_roots_not_in_producer; - std::vector consumer_history = ExprSort::getExprs( + std::vector consumer_history = StmtSort::getExprs( FusionGuard::getCurFusion(), std::vector( consumer->domain()->domain().begin(), @@ -706,7 +706,7 @@ BestEffortReplay BestEffortReplay::replayCasP( } // Grab all exprs used to make the forwarded compliments - auto compliment_exprs = ExprSort::getExprs( + auto compliment_exprs = StmtSort::getExprs( FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()}); // Figure out if there are any leaves in compliment_exprs that aren't diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index cde502d636e..f1c4ae378b5 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index d0d03532cd6..7ea96f74bf1 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -49,23 +50,26 @@ class ReplaySelf : public ReplayTransformations { // Manually replay the split, following the output of the operations. // This is so rfactor ops are replayed correctly. - IterDomain* ido = new IterDomain( - new Int(0), + IterDomain* ido = IrBuilder::create( + s->container(), + s->container()->zeroVal(), s->innerSplit() ? remainder->as() : s->factor(), s->outer()->getParallelType(), s->outer()->getIterType(), s->outer()->isRFactorProduct()); // inner IterDomain - IterDomain* idi = new IterDomain( - new Int(0), + IterDomain* idi = IrBuilder::create( + s->container(), + s->container()->zeroVal(), s->innerSplit() ? s->factor() : remainder->as(), s->inner()->getParallelType(), s->inner()->getIterType(), s->inner()->isRFactorProduct()); // Generate the split node - new Split( + IrBuilder::create( + s->container(), ido, idi, mapped, @@ -112,14 +116,16 @@ class ReplaySelf : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - IterDomain* merged_id = new IterDomain( - new Int(0), + IterDomain* merged_id = IrBuilder::create( + m->container(), + m->container()->zeroVal(), merged_id_size->as(), m->out()->getParallelType(), m->outer()->getIterType(), m->out()->isRFactorProduct()); - new Merge(merged_id, id_outer_mapped, id_inner_mapped); + IrBuilder::create( + m->container(), merged_id, id_outer_mapped, id_inner_mapped); // Remove inputs from the leaf IDs leaf_ids_.erase(id_outer_mapped); @@ -197,7 +203,8 @@ TensorDomain* TransformReplay::fullSelfReplay( "Error during replay, didn't replay an axis."); new_rfactor_domain[i++] = it->second; } - return new TensorDomain( + return IrBuilder::create( + self->container(), new_self_root->getRootDomain(), new_rfactor_domain, new_domain, @@ -205,8 +212,11 @@ TensorDomain* TransformReplay::fullSelfReplay( } } - return new TensorDomain( - new_self_root->getRootDomain(), new_domain, new_self_root->contiguity()); + return IrBuilder::create( + self->container(), + new_self_root->getRootDomain(), + new_domain, + new_self_root->contiguity()); } // Producer could have rfactor axes which consumer may want replayed. We can @@ -407,7 +417,8 @@ std::pair TransformReplay::replayPasC( new_IDs.push_back(id); } } - TensorDomain* replayed = new TensorDomain( + TensorDomain* replayed = IrBuilder::create( + producer->container(), producer->getRootDomain(), producer->getRFactorDomain(), new_IDs, @@ -604,7 +615,8 @@ std::pair TransformReplay::replayCasP( if (used_IDs.find(id) == used_IDs.end()) new_IDs.push_back(id); - TensorDomain* replayed = new TensorDomain( + TensorDomain* replayed = IrBuilder::create( + consumer->container(), consumer->getRootDomain(), consumer->getRFactorDomain(), new_IDs, diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 92898b54ba7..1fd3d110200 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -1,7 +1,7 @@ #pragma once +#include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index 8ac28cf3a2c..5939ffee289 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -52,23 +53,26 @@ class ReplayRFactor : public ReplayTransformations { // Manually replay the split, making reduction = false and rfactor = true // outer IterDomain - IterDomain* ido = new IterDomain( - new Int(0), + IterDomain* ido = IrBuilder::create( + s->container(), + IrBuilder::create(s->container(), 0), s->innerSplit() ? remainder->as() : s->factor(), ParallelType::Serial, rfactor_outer ? IterType::Reduction : IterType::Iteration, true); // broadcast // inner IterDomain - IterDomain* idi = new IterDomain( - new Int(0), + IterDomain* idi = IrBuilder::create( + s->container(), + IrBuilder::create(s->container(), 0), s->innerSplit() ? s->factor() : remainder->as(), ParallelType::Serial, rfactor_inner ? IterType::Reduction : IterType::Iteration, true); // Generate the split node - new Split(ido, idi, mapped, s->factor(), s->innerSplit()); + IrBuilder::create( + s->container(), ido, idi, mapped, s->factor(), s->innerSplit()); // Remove mapped id from leaf IDs leaf_ids_.erase(mapped); @@ -115,14 +119,16 @@ class ReplayRFactor : public ReplayTransformations { Val* merged_id_size = mul(id_outer_mapped->extent(), id_inner_mapped->extent()); - IterDomain* merged_id = new IterDomain( - new Int(0), + IterDomain* merged_id = IrBuilder::create( + m->container(), + IrBuilder::create(m->container(), 0), merged_id_size->as(), ParallelType::Serial, rfactor_output ? IterType::Reduction : IterType::Iteration, true); - new Merge(merged_id, id_outer_mapped, id_inner_mapped); + IrBuilder::create( + m->container(), merged_id, id_outer_mapped, id_inner_mapped); // Remove inputs from the leaf IDs leaf_ids_.erase(id_outer_mapped); @@ -238,7 +244,8 @@ TensorDomain* TransformRFactor::runReplay( for (auto id : orig_td_root) { // If this is an rfactor root, it will be a reduction in this stage if (rfactor_root_axes.find(id) != rfactor_root_axes.end()) { - new_root[i] = new IterDomain( + new_root[i] = IrBuilder::create( + id->container(), id->start(), id->extent(), id->stopOffset(), @@ -248,7 +255,8 @@ TensorDomain* TransformRFactor::runReplay( // If this is not an rfactor root, but a reduction root, it should be // turned into an iteration domain } else if (id->isReduction()) { - new_root[i] = new IterDomain( + new_root[i] = IrBuilder::create( + id->container(), id->start(), id->extent(), id->stopOffset(), @@ -296,7 +304,8 @@ TensorDomain* TransformRFactor::runReplay( if (dom->isRFactorProduct()) rfactor_root.push_back(dom); - return new TensorDomain( + return IrBuilder::create( + orig_td->container(), new_root, rfactor_root, new_domain, @@ -400,8 +409,11 @@ TensorDomain* TransformRFactor::runReplay2( } } - return new TensorDomain( - new_root, new_domain, std::vector(new_root.size(), true)); + return IrBuilder::create( + orig_td->container(), + new_root, + new_domain, + std::vector(new_root.size(), true)); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h index 551f67905b0..593eb287d0b 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/transform_view.cpp b/torch/csrc/jit/codegen/cuda/transform_view.cpp index ea4d188c092..433e34a11eb 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_view.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -44,11 +45,31 @@ class Transform { size_t index() const { return index_; } + + size_t originalIndex() const { + return original_index_; + } + + size_t newIndex() const { + return new_index_; + } + virtual ~Transform() = default; protected: - Transform(size_t index) : index_(index) {} + Transform(const ViewIndexState& state, size_t index) + : index_(index), + original_index_(state.original_view_index), + new_index_(Transform::computeNewIndex(state)) {} + const size_t index_ = 0; + const size_t original_index_ = 0; + const size_t new_index_ = 0; + + static size_t computeNewIndex(const ViewIndexState& state) { + return state.original_view_index - state.trivial_reduction_offset + + state.split_offset - state.merge_offset + state.broadcast_offset; + } }; //! Base class for all view tranformations - Merge, Split, Keep @@ -61,9 +82,11 @@ class ViewTransform : public Transform { std::vector& rfactor_domain) = 0; ~ViewTransform() override = default; + virtual bool isOriginalAxisDynamic() const = 0; + protected: ViewTransform(const ViewIndexState& state) - : Transform(ViewTransform::computeIndex(state)) {} + : Transform(state, ViewTransform::computeIndex(state)) {} static size_t computeIndex(const ViewIndexState& state) { return state.original_view_index - state.trivial_reduction_offset; @@ -71,6 +94,7 @@ class ViewTransform : public Transform { }; namespace { +typedef std::vector Sizes; const size_t kEmptyAxis = 0; const size_t kSingletonAxis = 1; @@ -86,6 +110,10 @@ class MergeTransform final : public ViewTransform { << std::endl; } + bool isOriginalAxisDynamic() const override { + return false; + } + void createRfactorDomain( const std::vector& new_root_domain, std::vector& rfactor_domain) override { @@ -108,14 +136,15 @@ class MergeTransform final : public ViewTransform { auto merged_extent = mul(merged_id->extent(), new_root_domain[index_ + 1]->extent()); - auto new_merged_id = new IterDomain( - new Int(0), + auto new_merged_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), merged_extent, ParallelType::Serial, IterType::Iteration, true); - new Merge(new_merged_id, merged_id, new_root_domain[index_ + 1]); + IrBuilder::create( + new_merged_id, merged_id, new_root_domain[index_ + 1]); rfactor_domain.push_back(new_merged_id); } @@ -140,6 +169,10 @@ class SplitTransform final : public ViewTransform { << " ARG: " << split_factor_ << std::endl; } + bool isOriginalAxisDynamic() const override { + return false; + } + void createRfactorDomain( const std::vector& new_root_domain, std::vector& rfactor_domain) override { @@ -150,7 +183,7 @@ class SplitTransform final : public ViewTransform { "\t Domain Size:\t", new_root_domain.size()); - auto factor = new Int(split_factor_); + auto factor = IrBuilder::create(split_factor_); IterDomain* id = nullptr; if (is_last_axis_rfactor_) { @@ -164,18 +197,22 @@ class SplitTransform final : public ViewTransform { Val* remainder = ceilDiv(id->extent(), factor); // outer loop IterDomain - IterDomain* factor_id = new IterDomain( - new Int(0), factor, id->getParallelType(), id->getIterType(), true); + IterDomain* factor_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), + factor, + id->getParallelType(), + id->getIterType(), + true); // inner loop IterDomain - IterDomain* remainder_id = new IterDomain( - new Int(0), + IterDomain* remainder_id = IrBuilder::create( + FusionGuard::getCurFusion()->zeroVal(), remainder->as(), ParallelType::Serial, IterType::Iteration, true); - new Split(factor_id, remainder_id, id, factor, false); + IrBuilder::create(factor_id, remainder_id, id, factor, false); rfactor_domain.push_back(factor_id); rfactor_domain.push_back(remainder_id); @@ -195,6 +232,10 @@ class KeepTransform final : public ViewTransform { output << "Keep Index: " << index_ << std::endl; } + bool isOriginalAxisDynamic() const override { + return true; + } + void createRfactorDomain( const std::vector& new_root_domain, std::vector& rfactor_domain) override { @@ -214,17 +255,11 @@ class KeepTransform final : public ViewTransform { class BroadcastTransform final : public Transform { public: BroadcastTransform(const ViewIndexState& state) - : Transform(BroadcastTransform::computeIndex(state)) {} + : Transform(state, Transform::computeNewIndex(state)) {} void toString(std::stringstream& output) const override { output << "Bcast Index: " << index_ << std::endl; } - - private: - static size_t computeIndex(const ViewIndexState& state) { - return state.original_view_index - state.trivial_reduction_offset + - state.split_offset - state.merge_offset + state.broadcast_offset; - } }; //! For any implicit broadcast dimensions in the original view, we remove @@ -232,7 +267,7 @@ class BroadcastTransform final : public Transform { class TrivialReductionTransform final : public Transform { public: TrivialReductionTransform(const ViewIndexState& state) - : Transform(TrivialReductionTransform::computeIndex(state)) {} + : Transform(state, TrivialReductionTransform::computeIndex(state)) {} void toString(std::stringstream& output) const override { output << "1-Red Index: " << index_ << std::endl; @@ -249,10 +284,11 @@ class TrivialReductionTransform final : public Transform { class AnalyzeViewTransformation { public: AnalyzeViewTransformation( - const std::vector root_domain, - const std::vector& original_view, - const std::vector& new_view) - : root_domain_(root_domain), + const Sizes& original_view, + const Sizes& new_view, + std::vector root_domain = {}) + : default_implicit_broadcast_(root_domain.empty()), + root_domain_(root_domain), original_view_(original_view), new_view_(new_view), transform_view_(original_view) { @@ -264,6 +300,24 @@ class AnalyzeViewTransformation { TORCH_INTERNAL_ASSERT(kOriginalNumElements == kNewNumElements); } + AnalyzeViewConstraint constraint() { + findTransformation(); + TORCH_INTERNAL_ASSERT( + validate(), + "Analyze View Transformation failed to find valid transformation.\n", + toString()); + std::vector original_constraint( + original_view_.begin(), original_view_.end()); + std::vector new_constraint(new_view_.begin(), new_view_.end()); + for (auto& vt : view_transforms_) { + if (vt->isOriginalAxisDynamic()) { + original_constraint[vt->originalIndex()] = -1; + new_constraint[vt->newIndex()] = -1; + } + } + return {original_constraint, new_constraint}; + } + AnalyzeViewResult run() { findTransformation(); TORCH_INTERNAL_ASSERT( @@ -382,6 +436,15 @@ class AnalyzeViewTransformation { return true; } + bool isImplicitBroadcast(size_t original_view_index) const { + if (default_implicit_broadcast_) { + return original_view_[original_view_index] == 1; + } else { + TORCH_INTERNAL_ASSERT(!root_domain_.empty()); + return root_domain_[original_view_index]->isImplicitBroadcast(); + } + } + //! This utility class merges a fixed set of axes together //! according to some invariant. Implicit broadcast axes cannot be //! merged with standard iterDomains, so they are handled separately @@ -400,8 +463,7 @@ class AnalyzeViewTransformation { bool any_merge = false; for (size_t idx = 0; idx < num_merge_axes_; ++idx) { - if (avt_->root_domain_[state_.original_view_index] - ->isImplicitBroadcast()) { + if (avt_->isImplicitBroadcast(state_.original_view_index)) { avt_->addTrivialReductionTransform(); } else { avt_->addMergeTransform( @@ -603,9 +665,10 @@ class AnalyzeViewTransformation { std::vector> trivial_reduction_transforms_; + bool default_implicit_broadcast_ = true; const std::vector root_domain_; - const std::vector& original_view_; - const std::vector& new_view_; + const Sizes& original_view_; + const Sizes& new_view_; // transform_view is a mutable view and is initialized with the original_view. // It is used to track the current state of the original tensor domain. @@ -622,7 +685,7 @@ class AnalyzeViewTransformation { // If transform size != original size for an axis, then the transformation // uses the last rfactor domain. Otherwise, it is a root domain // transformation. - std::vector transform_view_; + Sizes transform_view_; }; //! Create new TensorDomain with a modified rfactor domain using the specified @@ -644,7 +707,7 @@ TensorDomain* createViewDomain( t->createRfactorDomain(new_root_domain, rfactor_domain); } - return new TensorDomain( + return IrBuilder::create( new_root_domain, rfactor_domain, rfactor_domain, @@ -652,11 +715,19 @@ TensorDomain* createViewDomain( } //! Infer -1 value in new view sizes from original view sizes -std::vector inferNewViewShape( - const std::vector& original_view, +std::pair inferNewViewShape( + const std::vector& original_sizes, const std::vector& new_sizes) { - std::vector new_view(new_sizes.size()); + bool valid_original_sizes = std::all_of( + original_sizes.begin(), original_sizes.end(), [](int64_t dim) { + return dim > 0; + }); + TORCH_INTERNAL_ASSERT(valid_original_sizes); + Sizes original_view(original_sizes.begin(), original_sizes.end()); + Sizes new_view(new_sizes.size()); + + // TODO: refactor int64_t dynamic_index = -1; size_t new_size_num_elements = 1; for (size_t idx = 0; idx < new_sizes.size(); ++idx) { @@ -665,6 +736,7 @@ std::vector inferNewViewShape( dynamic_index == -1, "Only one dimension can by inferred.") dynamic_index = idx; } else { + TORCH_INTERNAL_ASSERT(new_sizes[idx] > 0); new_size_num_elements *= new_sizes[idx]; new_view[idx] = new_sizes[idx]; } @@ -676,7 +748,7 @@ std::vector inferNewViewShape( new_view[dynamic_index] = kNumElements / new_size_num_elements; } - return new_view; + return {original_view, new_view}; } } // namespace @@ -690,22 +762,24 @@ AnalyzeViewResult analyzeView( FUSER_PERF_SCOPE("analyzeView"); TORCH_INTERNAL_ASSERT( tv->getMaybeRFactorDomain().size() == original_sizes.size()); - - bool valid_original_sizes = std::all_of( - original_sizes.begin(), original_sizes.end(), [](int64_t dim) { - return dim > 0; - }); - - TORCH_INTERNAL_ASSERT(valid_original_sizes); - - std::vector original_view( - original_sizes.begin(), original_sizes.end()); - auto new_view = inferNewViewShape(original_view, new_sizes); + auto sizes = inferNewViewShape(original_sizes, new_sizes); AnalyzeViewTransformation analyzer( - tv->getRootDomain(), original_view, new_view); + sizes.first /* original_view */, + sizes.second /* new_view */, + tv->getRootDomain()); return analyzer.run(); } +AnalyzeViewConstraint analyzeViewConstraint( + const std::vector& original_sizes, + const std::vector& new_sizes) { + FUSER_PERF_SCOPE("analyzeViewConstraint"); + auto sizes = inferNewViewShape(original_sizes, new_sizes); + AnalyzeViewTransformation analyzer( + sizes.first /* original_view */, sizes.second /* new_view */); + return analyzer.constraint(); +} + //! Create new TensorDomain with a modified rfactor domain using the specified //! view transformations TensorDomain* transformView( diff --git a/torch/csrc/jit/codegen/cuda/transform_view.h b/torch/csrc/jit/codegen/cuda/transform_view.h index e7473a1b9b4..f8a986048be 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.h +++ b/torch/csrc/jit/codegen/cuda/transform_view.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -40,6 +40,11 @@ struct AnalyzeViewResult { std::vector> transforms; }; +struct AnalyzeViewConstraint { + std::vector original_constraint; + std::vector new_constraint; +}; + // Find the transformations necessary to convert TensorView // from original size to new size. AnalyzeViewResult analyzeView( @@ -47,6 +52,11 @@ AnalyzeViewResult analyzeView( const std::vector& original_sizes, const std::vector& new_sizes); +// Find the constraints derived from the view transformations +AnalyzeViewConstraint analyzeViewConstraint( + const std::vector& original_sizes, + const std::vector& new_sizes); + // Generate a new TensorDomain from the given view transformations. // The original root domain is kept in the new TensorDomain, // but a new rfactor domain is created from the view transformations. diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 3afb1b540b8..e883421eb1e 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -87,6 +87,9 @@ ValType promote_type(const ValType& t1, const ValType& t2) { (t1 == ValType::Scalar || t1 == ValType::NamedScalar)) { return ValType::Scalar; } + if (t1 == ValType::NamedScalar && t2 == ValType::NamedScalar) { + return ValType::Scalar; + } TORCH_CHECK(false, "Expected promotable ValTypes but got: ", t1, " and ", t2); } @@ -107,7 +110,7 @@ static const char* data_type2string(DataType t) { case DataType::Int32: return "int"; case DataType::Null: - return "nullptr"; + return "null_type"; default: break; } @@ -127,6 +130,10 @@ static const char* val_type2string(ValType t) { return "Scalar"; case ValType::NamedScalar: return "NamedScalar"; + case ValType::Predicate: + return "Predicate"; + case ValType::TensorIndex: + return "TensorIndex"; default: TORCH_INTERNAL_ASSERT(false, "No string found for val type."); } @@ -144,12 +151,38 @@ static const char* expr_type2string(ExprType t) { return "ReductionOp"; case ExprType::BroadcastOp: return "BroadcastOp"; + case ExprType::WelfordOp: + return "WelfordOp"; + case ExprType::TransposeOp: + return "TransposeOp"; case ExprType::ShiftOp: return "ShiftOp"; + case ExprType::GatherOp: + return "GatherOp"; + case ExprType::ViewOp: + return "ViewOp"; case ExprType::Split: return "Split"; case ExprType::Merge: return "Merge"; + case ExprType::Allocate: + return "Allocate"; + case ExprType::Sync: + return "Sync"; + case ExprType::InitMagicZero: + return "InitMagicZero"; + case ExprType::UpdateMagicZero: + return "UpdateMagicZero"; + case ExprType::ForLoop: + return "ForLoop"; + case ExprType::IfThenElse: + return "IfThenElse"; + case ExprType::GridReduction: + return "GridReduction"; + case ExprType::GridBroadcast: + return "GridBroadcast"; + case ExprType::GridWelford: + return "GridWelford"; default: TORCH_INTERNAL_ASSERT(false, "No string found for expr type."); } @@ -281,7 +314,6 @@ bool needFloatSuffix(BinaryOpType t) { case BinaryOpType::Atan2: case BinaryOpType::Div: case BinaryOpType::Fmod: - case BinaryOpType::Pow: return true; default: return false; @@ -522,6 +554,7 @@ static const char* supported_casts2string( case supported_switch_pair(DataType::Float, DataType::Int): case supported_switch_pair(DataType::Double, DataType::Int): return "(int64_t)"; + case supported_switch_pair(DataType::Int, DataType::Int32): case supported_switch_pair(DataType::Float, DataType::Int32): case supported_switch_pair(DataType::Double, DataType::Int32): return "(int32_t)"; diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 43aadb62006..ea7e8bd04d3 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include @@ -32,6 +32,8 @@ enum class ValType { TensorView, Scalar, NamedScalar, + Predicate, + TensorIndex, }; // Manual - The user provides the Bool value. Predicate generation is bypassed. @@ -73,6 +75,15 @@ enum class ExprType { ViewOp, Split, Merge, + Allocate, + Sync, + InitMagicZero, + UpdateMagicZero, + ForLoop, + IfThenElse, + GridReduction, + GridBroadcast, + GridWelford, }; enum class UnaryOpType { @@ -257,8 +268,11 @@ std::string stringifyThread(const ParallelType); std::string typePrefix(const DataType); // TODO: ThreadDim should be BlockDim and BlockDim should be GridDim +// Returns if parallel type is TID[x, y, z] TORCH_CUDA_CU_API bool isParallelTypeThreadDim(ParallelType); +// Returns if parallel type is BID[x, y, z] TORCH_CUDA_CU_API bool isParallelTypeBlockDim(ParallelType); +// Returns if parallel type is a grid or block parallelization dimension TORCH_CUDA_CU_API bool isParallelTypeThread(ParallelType); TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType); diff --git a/torch/csrc/jit/codegen/cuda/type_inference.cpp b/torch/csrc/jit/codegen/cuda/type_inference.cpp index 8c7d7d36a06..a8facc6a45b 100644 --- a/torch/csrc/jit/codegen/cuda/type_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/type_inference.cpp @@ -141,6 +141,7 @@ class NaiveTypePropagator { } // binary operations that forward meta info and broadcast shape: case aten::gelu_backward: + case aten::tanh_backward: case aten::mul: case aten::div: case aten::min: @@ -414,19 +415,14 @@ class NaiveTypePropagator { node->output()->setType(out_type->withDim(c10::nullopt)); break; } - /* - // TODO: Enable view in parser by detecting non-alias view operation - case aten::view: - case aten::reshape: { + case prim::unsqueeze_copy: + case prim::squeeze_copy: + case prim::reshape_copy: + case prim::view_copy: { auto out_type = node->input(0)->type()->cast(); - auto size_optional = constant_as>(node->input(1)); - TORCH_INTERNAL_ASSERT( - size_optional.has_value(), "The size parameter is required."); - auto new_size = size_optional->vec(); - node->output()->setType(out_type->withSizes(new_size)); + node->output()->setType(out_type); break; } - */ case aten::type_as: { const auto type0 = getInputTensorType(node, 0); const auto type1 = getInputTensorType(node, 1); diff --git a/torch/csrc/jit/codegen/cuda/type_promotion.cpp b/torch/csrc/jit/codegen/cuda/type_promotion.cpp index 016e8825acf..68a38e67378 100644 --- a/torch/csrc/jit/codegen/cuda/type_promotion.cpp +++ b/torch/csrc/jit/codegen/cuda/type_promotion.cpp @@ -55,13 +55,14 @@ at::native::ResultTypeState updateResultTypeState( TORCH_INTERNAL_ASSERT( !c10::isComplexType(scalar), "NvFuser does not support complex data types."); + at::native::ResultTypeState new_state = in_state; c10::ScalarType current = scalar; if (c10::isFloatingType(scalar)) { current = c10::typeMetaToScalarType(at::get_default_dtype()); } new_state.wrappedResult = - promoteTypesSkipUndefined(in_state.wrappedResult, scalar); + promoteTypesSkipUndefined(in_state.wrappedResult, current); return new_state; } @@ -195,11 +196,16 @@ std::vector promoteValues( Val* optionalCast(DataType dtype, Val* v) { TORCH_INTERNAL_ASSERT(v->getDataType().has_value()); + // Avoid casting Float/Int scalar to any corresponding FloatingPoint/Integral + // type in fusion. Instead, we cast them directly. The exception is Bool, + // which is always casted to the desired type. const bool kSameDtype = v->getDataType().value() == dtype; const bool kIsScalarFloat = !v->isA() && isFloatingPointType(dtype); + const bool kIsScalarInt = !v->isA() && isIntegralType(dtype); if (kSameDtype || - (kIsScalarFloat && isFloatingPointType(v->getDataType().value()))) { + (kIsScalarFloat && isFloatingPointType(v->getDataType().value())) || + (kIsScalarInt && isIntegralType(v->getDataType().value()))) { return v; } else { return castOp(dtype, v); diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index 67c8359b502..127078b45f7 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -143,6 +143,19 @@ void debugPrint(const c10::TensorTypePtr& type) { } #pragma clang diagnostic pop +bool is_cpu_scalar(const at::Tensor& tensor) { + return tensor.device().is_cpu() && tensor.numel() == 1 && tensor.dim() == 0; +} + +bool is_cpu_scalar(const c10::TensorType& tensor_type) { + auto opt_device = tensor_type.device(); + auto opt_dim = tensor_type.dim(); + auto opt_numel = tensor_type.numel(); + return opt_device.has_value() && opt_device.value().is_cpu() && + opt_dim.has_value() && opt_numel.has_value() && opt_dim.value() == 0 && + opt_numel.value() == 1; +} + bool isDebugDumpEnabled(DebugDumpOption option) { const static auto dump_options = parseDebugDumpOptions(); return dump_options.at(option); @@ -158,6 +171,14 @@ bool disableRNGUnrolling() { return disable_rng_unroll ? atoi(disable_rng_unroll) : false; } +std::vector getTensorSizes(TensorTypePtr const& tensor_type) { + TORCH_INTERNAL_ASSERT(tensor_type != nullptr, "Input must be a Tensor."); + auto optional_sizes = tensor_type->sizes().concrete_sizes(); + TORCH_INTERNAL_ASSERT( + optional_sizes.has_value(), "Missing size information for the tensor."); + return optional_sizes.value(); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index f56d0f8d52e..c035cdeae24 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include namespace torch { namespace jit { @@ -10,6 +11,9 @@ namespace cuda { void debugPrint(const c10::TensorTypePtr& type); +bool is_cpu_scalar(const at::Tensor& tensor); +bool is_cpu_scalar(const c10::TensorType& tensor_type); + //! Types of debug print-outs //! //! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable @@ -116,6 +120,8 @@ constexpr unsigned int switch_pair(T t1, T t2) { return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2; } +std::vector getTensorSizes(TensorTypePtr const& tensor_type); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index fa840621682..0129938568c 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -152,11 +152,11 @@ class AliasDb { * this. */ // Copy `existing`s aliasing info to `new_value`, and remove `existing`. - void replaceWithNewValue(Value* existing, Value* new_value); + TORCH_API void replaceWithNewValue(Value* existing, Value* new_value); // Copy `from`s aliasing info to `to`. - void copyValue(Value* from, Value* to); + TORCH_API void copyValue(Value* from, Value* to); // Create a new `value` that does not alias anything else. - void createValue(const Value* value); + TORCH_API void createValue(const Value* value); // Enable more precise treatment of prim::TupleConstruct. void enablePreciseTupleContainerAnalysis(); diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 9ad8934d55b..ca07e63cb61 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1308,6 +1308,10 @@ def forward(self, a) -> MyModule: obj = obj.__original_fn _rcb = _jit_internal.createResolutionCallbackFromClosure(obj) + # some functions are explicitly marked as not supported in script mode + if hasattr(obj, "__script_unsupported"): + raise RuntimeError("TorchScript error: " + obj.__script_unsupported) + _check_directly_compile_overloaded(obj) maybe_already_compiled_fn = _try_get_jit_cached_function(obj) if maybe_already_compiled_fn: