Nvfuser code bump 2_1_2022 (#72127)
Summary:
Things changed in this PR that require review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf to the shape registry
4. torch/jit/_script.py : throws an error when scripting a model that uses autocast as a decorator, since that usage is not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar CPU tensor promotion to support inter-device operations between CPU scalar tensors and CUDA tensors (sketched below)
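On item 4: a zero-dim CPU tensor may participate in ops on CUDA tensors under eager-mode type promotion, and the fuser can now accept the same pattern instead of falling back. A minimal illustration in libtorch (this snippet is a sketch, not part of the diff):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto scale = torch::tensor(0.5);              // zero-dim CPU "scalar tensor"
  auto x = torch::randn({1024}, torch::kCUDA);  // ordinary CUDA tensor
  // Eager mode promotes the CPU scalar tensor onto the CUDA device;
  // the nvfuser update lets fused kernels handle this mixed-device
  // pattern rather than rejecting the fusion group.
  auto y = x * scale;
  std::cout << y.device() << std::endl;  // cuda:0
  return 0;
}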

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
jjsjann123 authored and cyyever committed Feb 17, 2022
1 parent 20f8cdf commit 72a72d0
Showing 180 changed files with 15,144 additions and 11,754 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -45,6 +45,10 @@ namespace c10 {
   _(prim, CudaFusionGuard)          \
   _(prim, FunctionalGraph)          \
   _(prim, add_optional)             \
+  _(prim, view_copy)                \
+  _(prim, reshape_copy)             \
+  _(prim, squeeze_copy)             \
+  _(prim, unsqueeze_copy)           \
   _(prim, DifferentiableGraph)      \
   _(prim, TensorExprGroup)          \
   _(prim, TensorExprDynamicGroup)   \
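These macro entries intern the new symbols so they can be used as node kinds in the JIT IR; the *_copy names give nvfuser alias-free stand-ins for the corresponding view operations. A sketch of how such a symbol is typically consumed (this helper is illustrative, not part of the diff):

#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/ir/ir.h>

// True if `node` is one of the alias-free copy variants added above.
bool isAliasFreeCopy(const torch::jit::Node* node) {
  auto kind = node->kind();
  return kind == c10::prim::view_copy || kind == c10::prim::reshape_copy ||
      kind == c10::prim::squeeze_copy || kind == c10::prim::unsqueeze_copy;
}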
5 changes: 3 additions & 2 deletions benchmarks/cpp/nvfuser/batch_norm.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -44,8 +45,8 @@ static void setupBatchNorm(Fusion* fusion, DataType dtype) {
     bias = castOp(DataType::Float, bias);
   }
 
-  auto momentum_ptr = new Double(kMomentum);
-  auto eps_ptr = new Double(kEps);
+  auto momentum_ptr = IrBuilder::create<Double>(kMomentum);
+  auto eps_ptr = IrBuilder::create<Double>(kEps);
 
   auto result = batch_norm(
       input,
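This and the remaining benchmark diffs repeat one mechanical change: raw `new Double(...)` / `new Int(...)` allocations become `IrBuilder::create<...>(...)`, which registers each scalar IR node with the fusion container that is active when it is created, instead of leaving ownership implicit. A minimal sketch of the pattern, assuming the nvfuser headers above (the function itself is illustrative):

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>

using namespace torch::jit::fuser::cuda;

void buildEps() {
  Fusion fusion;
  FusionGuard fg(&fusion);  // makes `fusion` the active IR container

  // Old style: a bare allocation with implicit ownership.
  //   auto eps = new Double(1e-5);
  // New style: the node is built through IrBuilder and its lifetime
  // is tied to `fusion`.
  auto eps = IrBuilder::create<Double>(1e-5);
  (void)eps;  // would normally feed an op, e.g. layer_norm(..., eps)
}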
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/batch_norm_backward.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -49,7 +50,7 @@ static void setupBatchNorm_BWD(Fusion* fusion, DataType dtype) {
     grad_output = castOp(DataType::Float, grad_output);
   }
 
-  auto eps_ptr = new Double(kEps);
+  auto eps_ptr = IrBuilder::create<Double>(kEps);
 
   auto result = batch_norm_backward(
       input,
18 changes: 10 additions & 8 deletions benchmarks/cpp/nvfuser/bert.cpp
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -36,7 +37,7 @@ static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) {
   fusion->addInput(tv1);
 
   // TODO: should be input
-  auto d16 = new Double(1.0);
+  auto d16 = IrBuilder::create<Double>(1.0);
 
   if (is_fp16) {
     tv0 = castOp(DataType::Float, tv0);
@@ -47,7 +48,7 @@ static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) {
   auto tv3 = add(tv2, tv0);
 
   auto tv10 = softmax(tv3, 3);
-  auto dropout_tvs = dropout(tv10, new Double(0.9));
+  auto dropout_tvs = dropout(tv10, IrBuilder::create<Double>(0.9));
   auto tv12 = dropout_tvs.mask;
   auto tv14 = dropout_tvs.output;
 
@@ -83,9 +84,9 @@ static void setupDivMaxSoftmaxDropoutBackward(Fusion* fusion, DataType dtype) {
   }
 
   // TODO: should be inputs
-  auto d32 = new Double(1.0);
+  auto d32 = IrBuilder::create<Double>(1.0);
   // fusion->addInput(d32);
-  auto d33 = new Double(2.0);
+  auto d33 = IrBuilder::create<Double>(2.0);
   // fusion->addInput(d33);
 
   auto tv4 = mul(tv2, tv3);
@@ -252,14 +253,15 @@ static void setupBiasDropoutAddLayernormFwd(Fusion* fusion, DataType dtype) {
 
   auto tv5 = broadcast(tv4, {true, true, false});
   auto tv6 = add(tv3, tv5);
-  auto dropout_outs = dropout(tv6, new Double(0.9));
+  auto dropout_outs = dropout(tv6, IrBuilder::create<Double>(0.9));
 
   auto tv8 = dropout_outs.output;
   auto tv10 = dropout_outs.mask;
 
   auto tv11 = add(tv10, tv2);
 
-  auto layer_norm_outs = layer_norm(tv11, 1, tv0, tv1, new Double(1e-5));
+  auto layer_norm_outs =
+      layer_norm(tv11, 1, tv0, tv1, IrBuilder::create<Double>(1e-5));
   auto tv14 = layer_norm_outs.output;
   auto tv21 = layer_norm_outs.mean;
   auto tv26 = layer_norm_outs.invstd;
@@ -481,7 +483,7 @@ static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) {
     tv1 = castOp(DataType::Float, tv1);
     tv8 = castOp(DataType::Float, tv8);
   }
-  auto d36 = mul(new Double(1.0), tv1->axis(2)->extent());
+  auto d36 = mul(IrBuilder::create<Double>(1.0), tv1->axis(2)->extent());
   auto d47 = unaryOp(UnaryOpType::Reciprocal, d36);
 
   auto tv9 = broadcast(tv5, {true, true, false});
@@ -583,7 +585,7 @@ static void setupBiasDropoutAddLayernormBwd3(Fusion* fusion, DataType dtype) {
   }
 
   // Uncertain this is the right value, but going for it anyways
-  auto d34 = div(new Double(1.0), tv0->axis(2)->extent());
+  auto d34 = div(IrBuilder::create<Double>(1.0), tv0->axis(2)->extent());
 
   auto tv25 = mul(tv21, tv0);
   auto tv26 = mul(tv25, d34);
19 changes: 10 additions & 9 deletions benchmarks/cpp/nvfuser/gelu_backward.cpp
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
 
@@ -41,23 +42,23 @@ static void setupFusion(Fusion* fusion) {
   auto t5 = castOp(DataType::Float, t4);
   auto t6 = broadcast(t3, {true, true, false});
   auto t7 = add(t6, t5);
-  auto t8 = mul(t7, new Double(k_079));
-  auto t9 = mul(t7, new Double(k_004));
+  auto t8 = mul(t7, IrBuilder::create<Double>(k_079));
+  auto t9 = mul(t7, IrBuilder::create<Double>(k_004));
   auto t10 = mul(t9, t7);
-  auto t11 = add(t10, new Int(1));
+  auto t11 = add(t10, IrBuilder::create<Int>(1));
   auto t12 = mul(t8, t11);
   auto t13 = unaryOp(UnaryOpType::Tanh, t12);
-  auto t14 = mul(t7, new Double(0.5));
+  auto t14 = mul(t7, IrBuilder::create<Double>(0.5));
   auto t15 = mul(t13, t13);
   auto t16 = unaryOp(UnaryOpType::Neg, t15);
-  auto t17 = add(t16, new Int(1));
-  auto t18 = mul(t7, new Double(k_010));
+  auto t17 = add(t16, IrBuilder::create<Int>(1));
+  auto t18 = mul(t7, IrBuilder::create<Double>(k_010));
   auto t19 = mul(t18, t7);
-  auto t20 = add(t19, new Double(k_079));
+  auto t20 = add(t19, IrBuilder::create<Double>(k_079));
   auto t21 = mul(t17, t20);
   auto t22 = mul(t14, t21);
-  auto t23 = add(t13, new Int(1));
-  auto t24 = mul(t23, new Double(0.5));
+  auto t23 = add(t13, IrBuilder::create<Int>(1));
+  auto t24 = mul(t23, IrBuilder::create<Double>(0.5));
   auto t25 = add(t22, t24);
   auto t26 = mul(t25, t1);
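For reference, the constants in this benchmark implement the derivative of the tanh approximation of GELU (reading the t8–t26 chain above; the numeric values k_079 ≈ √(2/π) ≈ 0.79788456, k_004 ≈ 0.044715, and k_010 ≈ 3 · 0.044715 · √(2/π) are the standard GELU-approximation constants, assumed here rather than shown in the hunk). Per element the fusion computes

  \mathrm{gelu}(x) \approx \frac{x}{2}\,(1 + \tanh u), \qquad u = \sqrt{2/\pi}\,(x + 0.044715\,x^3)

  \frac{d\,\mathrm{gelu}}{dx} \approx \frac{1}{2}(1 + \tanh u) + \frac{x}{2}\,(1 - \tanh^2 u)\,\sqrt{2/\pi}\,(1 + 3 \cdot 0.044715\,x^2)

multiplied by the incoming gradient t1.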
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/heuristic_cache.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -129,7 +130,7 @@ static auto getLayerForwardNormRuntime(
   Fusion& fusion = *fusion_ptr.get();
 
   const float kEps = 1e-5;
-  Double* eps_ptr = new Double(kEps);
+  Double* eps_ptr = IrBuilder::create<Double>(kEps);
 
   auto input = makeSymbolicTensor(shape.size());
   fusion.addInput(input);
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/heuristic_lookup.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -129,7 +130,7 @@ static auto getLayerForwardNormRuntime(
   Fusion& fusion = *fusion_ptr.get();
 
   const float kEps = 1e-5;
-  Double* eps_ptr = new Double(kEps);
+  Double* eps_ptr = IrBuilder::create<Double>(kEps);
 
   auto input = makeSymbolicTensor(shape.size());
   fusion.addInput(input);
5 changes: 3 additions & 2 deletions benchmarks/cpp/nvfuser/instance_norm.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
 #include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
@@ -39,8 +40,8 @@ static void setupInstanceNorm(Fusion* fusion, DataType dtype) {
   const bool kTraining = true;
   const float kMomentum = 0.1;
   const float kEps = 1e-5;
-  auto momentum_ptr = new Double(kMomentum);
-  auto eps_ptr = new Double(kEps);
+  auto momentum_ptr = IrBuilder::create<Double>(kMomentum);
+  auto eps_ptr = IrBuilder::create<Double>(kEps);
 
   auto norm = instance_norm(
       input,
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/layer_norm.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -24,7 +25,7 @@ static void setupLayerNorm(Fusion* fusion, DataType dtype) {
   const int kReductionAxis = 1;
   const float kEps = 1e-5;
 
-  Double* eps_ptr = new Double(kEps);
+  Double* eps_ptr = IrBuilder::create<Double>(kEps);
 
   // setup fusion
   auto input = makeContigTensor(2, dtype);
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/layer_norm_backward.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -22,7 +23,7 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) {
   TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
 
   const int kReductionAxis = 1;
-  Double* eps_ptr = new Double(1e-5);
+  Double* eps_ptr = IrBuilder::create<Double>(1e-5);
 
   // setup fusion
   auto grad_out = makeContigTensor(2, dtype);
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/shape_inference.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -151,7 +152,7 @@ static auto getLayerForwardNormRuntime(
   Fusion& fusion = *fusion_ptr.get();
 
   const float kEps = 1e-5;
-  Double* eps_ptr = new Double(kEps);
+  Double* eps_ptr = IrBuilder::create<Double>(kEps);
 
   auto input = makeSymbolicTensor(shape.size());
   fusion.addInput(input);
7 changes: 4 additions & 3 deletions benchmarks/cpp/nvfuser/softmax_dropout.cpp
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
@@ -35,7 +36,7 @@ static void setupSoftmaxDropout(
   auto attention_scores = makeContigTensor(4, dtype);
   auto attention_mask = makeContigTensor(4, dtype);
 
-  Double* divisor = new Double();
+  Double* divisor = IrBuilder::create<Double>();
 
   fusion->addInput(attention_scores);
   fusion->addInput(attention_mask);
@@ -49,8 +50,8 @@ static void setupSoftmaxDropout(
   attention_scores = div(attention_scores, divisor);
   attention_scores = add(attention_scores, attention_mask);
   auto attention_probs = softmax(attention_scores, kReductionAxis);
-  auto prob = new Double(kDropoutProbability);
-  auto scale = new Double(kScale);
+  auto prob = IrBuilder::create<Double>(kDropoutProbability);
+  auto scale = IrBuilder::create<Double>(kScale);
   auto dropout_results = dropout(attention_probs, prob, scale);
   auto output = dropout_results.output;
10 changes: 5 additions & 5 deletions benchmarks/cpp/nvfuser/utils.cpp
@@ -16,8 +16,8 @@ std::string toString(ReductionParams rparams) {
   if (rparams.schedule_3D) {
     ss << "3D Schedule // "
        << "Outer Reduction: "
-       << (rparams.cross_block_outer_reduce ? "cross block / " : "")
-       << (rparams.cross_grid_outer_reduce ? "cross grid / " : "")
+       << (rparams.cross_block_outer_reduction ? "cross block / " : "")
+       << (rparams.cross_grid_outer_reduction ? "cross grid / " : "")
        << (rparams.split_grid_dim_outer_reduction ? "split grid dim / " : "");
     if (rparams.batches_per_block_outer_reduction > 1 ||
         rparams.persistent_kernel) {
@@ -38,17 +38,17 @@ std::string toString(ReductionParams rparams) {
   }
 
   ss << " // Inner Reduction Domain: "
-     << (rparams.cross_block_inner_reduce ? "cross block reduction / " : "")
+     << (rparams.cross_block_inner_reduction ? "cross block reduction / " : "")
      << (rparams.pad_inner_reduction_to_warp ? "pad to warp / " : "")
-     << (rparams.cross_grid_inner_reduce ? "cross grid reduction / " : "");
+     << (rparams.cross_grid_inner_reduction ? "cross grid reduction / " : "");
 
   if (rparams.batches_per_block_inner_reduction > 1 ||
       rparams.persistent_kernel) {
     ss << "persistent batch - " << rparams.batches_per_block_inner_reduction
        << " / ";
   }
 
-  ss << (rparams.cross_grid_inner_reduce &&
+  ss << (rparams.cross_grid_inner_reduction &&
          rparams.split_grid_dim_inner_reduction
          ? "split grid dimension / "
          : "")
[Diffs for the remaining 167 changed files are not shown.]