Support broadcasts of predicated tensors within thread blocks #100

Merged
merged 18 commits on Jul 8, 2020
401 changes: 327 additions & 74 deletions test/cpp/jit/test_gpu.cpp

Large diffs are not rendered by default.

144 changes: 74 additions & 70 deletions test/cpp/jit/tests.h
@@ -97,76 +97,80 @@ namespace jit {
_(FusionAliasing)

#if defined(USE_CUDA)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
_(CompleteArgumentSpec) \
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
_(GPU_IrGraphGenerator) \
_(GPU_FusionDispatch) \
_(GPU_FusionClear) \
_(GPU_FusionCopy) \
_(GPU_FusionMove) \
_(GPU_FusionSimpleArith) \
_(GPU_FusionExprEvalConstants) \
_(GPU_FusionExprEvalBindings) \
_(GPU_FusionExprEvalBasic) \
_(GPU_FusionExprEvalComplex) \
_(GPU_FusionExprEvalPostLower) \
_(GPU_FusionSimpleTypePromote) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
_(GPU_FusionTopoSort) \
_(GPU_FusionTensor) \
_(GPU_FusionTVSplit) \
_(GPU_FusionTVMerge) \
_(GPU_FusionTVReorder) \
_(GPU_FusionEquality) \
_(GPU_FusionReplaceAll) \
_(GPU_FusionParser) \
_(GPU_FusionDependency) \
_(GPU_FusionCodeGen) \
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
_(GPU_FusionForLoop) \
_(GPU_FusionLoopUnroll) \
_(GPU_FusionUnaryOps) \
_(GPU_FusionBinaryOps) \
_(GPU_FusionTernaryOps) \
_(GPU_FusionCompoundOps) \
_(GPU_FusionCastOps) \
_(GPU_FusionAdvancedComputeAt) \
_(GPU_FusionScalarInputs) \
_(GPU_FusionRFactorReplay) \
_(GPU_FusionReduction) \
_(GPU_FusionReduction2) \
_(GPU_FusionReduction3) \
_(GPU_FusionReduction4) \
_(GPU_FusionReduction5) \
_(GPU_FusionReductionTFT) \
_(GPU_FusionSimpleBCast) \
_(GPU_FusionSimpleGemm) \
_(GPU_FusionSoftmax) \
_(GPU_FusionSoftmaxComputeAt) \
_(GPU_FusionGridReduction1) \
_(GPU_FusionGridReduction2) \
_(GPU_FusionGridReduction3dim1) \
_(GPU_FusionGridReduction3dim0) \
_(GPU_FusionGridReduction4) \
_(GPU_FusionGridReduction5) \
_(GPU_FusionGridReduction6) \
_(GPU_FusionNonRedAxisBind) \
_(GPU_FusionBCastInnerDim) \
_(GPU_FusionBCastReduce) \
_(GPU_FusionSplitBCast) \
_(GPU_FusionComputeAtExprOrder) \
_(GPU_FusionZeroDimComputeAt) \
_(GPU_FusionZeroDimBroadcast) \
_(GPU_FusionZeroDimReduction) \
_(GPU_FusionReductionMultiConsumer)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
_(CompleteArgumentSpec) \
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
_(GPU_IrGraphGenerator) \
_(GPU_FusionDispatch) \
_(GPU_FusionClear) \
_(GPU_FusionCopy) \
_(GPU_FusionMove) \
_(GPU_FusionSimpleArith) \
_(GPU_FusionExprEvalConstants) \
_(GPU_FusionExprEvalBindings) \
_(GPU_FusionExprEvalBasic) \
_(GPU_FusionExprEvalComplex) \
_(GPU_FusionExprEvalPostLower) \
_(GPU_FusionSimpleTypePromote) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
_(GPU_FusionTopoSort) \
_(GPU_FusionTensor) \
_(GPU_FusionTVSplit) \
_(GPU_FusionTVMerge) \
_(GPU_FusionTVReorder) \
_(GPU_FusionEquality) \
_(GPU_FusionReplaceAll) \
_(GPU_FusionParser) \
_(GPU_FusionDependency) \
_(GPU_FusionCodeGen) \
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
_(GPU_FusionForLoop) \
_(GPU_FusionLoopUnroll) \
_(GPU_FusionUnaryOps) \
_(GPU_FusionBinaryOps) \
_(GPU_FusionTernaryOps) \
_(GPU_FusionCompoundOps) \
_(GPU_FusionCastOps) \
_(GPU_FusionAdvancedComputeAt) \
_(GPU_FusionScalarInputs) \
_(GPU_FusionRFactorReplay) \
_(GPU_FusionReduction) \
_(GPU_FusionReduction2) \
_(GPU_FusionReduction3) \
_(GPU_FusionReduction4) \
_(GPU_FusionReduction5) \
_(GPU_FusionReductionTFT) \
_(GPU_FusionSimpleBCast) \
_(GPU_FusionSimpleGemm) \
_(GPU_FusionSoftmax1D) \
_(GPU_FusionSoftmax1DNormalized) \
_(GPU_FusionSoftmax3D) \
_(GPU_FusionSoftmax3DNormalized) \
_(GPU_FusionSoftmaxComputeAt) \
_(GPU_FusionGridReduction1) \
_(GPU_FusionGridReduction2) \
_(GPU_FusionGridReduction3dim1) \
_(GPU_FusionGridReduction3dim0) \
_(GPU_FusionGridReduction4) \
_(GPU_FusionGridReduction5) \
_(GPU_FusionGridReduction6) \
_(GPU_FusionNonRedAxisBind) \
_(GPU_FusionBCastInnerDim) \
_(GPU_FusionBCastReduce) \
_(GPU_FusionSplitBCast) \
_(GPU_FusionComputeAtExprOrder) \
_(GPU_FusionZeroDimComputeAt) \
_(GPU_FusionZeroDimBroadcast) \
_(GPU_FusionZeroDimReduction) \
_(GPU_FusionReductionMultiConsumer) \
_(GPU_FusionBCastAfterReduce)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
63 changes: 54 additions & 9 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -1,6 +1,8 @@
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

#include <iostream>

@@ -478,15 +480,50 @@ void IRPrinter::handle(const ReductionOp* rop) {
}

void IRPrinter::handle(const BroadcastOp* bop) {
Owner:
Shouldn't we follow ReductionOp more closely and first check whether it's been lowered, to figure out how we're going to print it?

  bool lowered = rop->out()->getValType() == ValType::TensorIndex;
  if (!lowered) {
    os << rop->out() << " = reduction( " << rop->in()
       << ", op = " << rop->getReductionOpType()
       << ", initial value = " << rop->init() << " )\n";
    return;
  }

It looks to me like you're going to print broadcast::blockBroadcast<... even if it hasn't been lowered.

Collaborator (Author):
Thanks. That's right. Fixed.

indent();
handle(bop->out());
os << "\n";
indent_size++;
indent();
os << " = ";
handle(bop->in());
indent_size--;
os << ";\n";
// Check if we've lowered yet.
bool lowered = bop->out()->getValType() == ValType::TensorIndex;
if (!lowered) {
os << bop->out() << " = broadcast( " << bop->in() << " )\n";
return;
}

const ir_utils::ParallelTypeBitmap domains =
ir_utils::getParallelBroadcastDomains(bop, getThreadPredicateMap());
const bool thread_x = domains.get(ParallelType::TIDx);
const bool thread_y = domains.get(ParallelType::TIDy);
const bool thread_z = domains.get(ParallelType::TIDz);
const bool block_x = domains.get(ParallelType::BIDx);
const bool block_y = domains.get(ParallelType::BIDy);
const bool block_z = domains.get(ParallelType::BIDz);

const bool grid_broadcast_needed = block_x || block_y || block_z;
const bool block_broadcast_needed = thread_x || thread_y || thread_z;

TORCH_INTERNAL_ASSERT(
Owner:
Do we also have this check in IterDomain::parallelize?
!grid_broadcast_needed, "Parallel broadcast across blocks not supported");

if (block_broadcast_needed) {
indent();
os << "broadcast::blockBroadcast<";
os << (thread_x ? "true" : "false") << ", ";
os << (thread_y ? "true" : "false") << ", ";
os << (thread_z ? "true" : "false");
os << ">(";
handle(bop->out());
os << ", ";
handle(bop->in());
os << ");\n";
} else {
indent();
handle(bop->out());
os << "\n";
indent_size++;
indent();
os << " = ";
handle(bop->in());
indent_size--;
os << ";\n";
}
}

void IRPrinter::handle(const ForLoop* fl) {
@@ -640,6 +677,14 @@ void IRPrinter::printKernel(
os << "}\n";
}

const ThreadPredicateMap& IRPrinter::getThreadPredicateMap() {
if (thread_predicates_ == nullptr) {
Fusion* fusion = FusionGuard::getCurFusion();
thread_predicates_ = std::make_unique<ThreadPredicateMap>(fusion);
}
return *thread_predicates_;
}

std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
IRPrinter p(os);
p.handle(stmt);
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -3,6 +3,7 @@
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

#include <iostream>

@@ -126,6 +127,11 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
void printKernel(
const std::vector<Expr*>& exprs,
const std::string& kernel_name);

private:
std::unique_ptr<ThreadPredicateMap> thread_predicates_;

const ThreadPredicateMap& getThreadPredicateMap();
Collaborator:
Please separate private methods from private data into separate private sections, e.g.:

...
private:
  ... private methods ...

private:
  ... private data ...

Collaborator:
The rationale is to be able to locate the state easily, instead of having it mixed with methods.

Collaborator (Author):
I'm not sure I agree here. It just has one field and one method. To make it a little easier to separate, I added an empty line between them.

};

TORCH_CUDA_API std::ostream& operator<<(
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -156,7 +156,8 @@ std::pair<std::string, std::string> codeGeneration(Fusion* fusion) {
<< code_random_number_gen << "\n"
<< code_helper_funcs << "\n"
<< code_template_block_reduction << "\n"
<< code_template_grid_reduction << "\n";
<< code_template_grid_reduction << "\n"
<< code_template_block_broadcast << "\n";
std::stringstream cdg;
GPULower gpulw(fusion);
gpulw.printKernel(str_stream, kKernelName);
47 changes: 47 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
@@ -570,6 +570,53 @@ __device__ void gridReduce(T& out, T inp_val, Func reduction_op,
} // namespace reduction
)";

static auto code_template_block_broadcast = R"(
namespace broadcast {

template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
__host__ __device__ unsigned offset_of_source(const dim3& block_dim, const dim3& thread_idx) {
unsigned offset = 0;
if (!Z_THREAD)
offset = offset * block_dim.z + thread_idx.z;
if (!Y_THREAD)
offset = offset * block_dim.y + thread_idx.y;
if (!X_THREAD)
offset = offset * block_dim.x + thread_idx.x;
return offset;
}
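// For example, with X_THREAD = true and Y_THREAD = Z_THREAD = false, the
// offset reduces to threadIdx.z * blockDim.y + threadIdx.y, so every thread
// that shares the same (y, z) coordinates maps to the same shared-memory
// slot regardless of its threadIdx.x.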

/** Broadcasts within partitioned groups of threads.

X_THREAD: Broadcast from threadIdx.x == 0 if true
Y_THREAD: Broadcast from threadIdx.y == 0 if true
Z_THREAD: Broadcast from threadIdx.z == 0 if true
inp_val: Per-thread source value. Only valid when the thread is a source.
out: Per-thread output location
*/
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
__device__ void blockBroadcast(T& out, T inp_val) {

// Use the worst case for shared memory (up to 1024 threads per block).
__shared__ T shared_mem[1024];

const bool has_valid_data =
(!X_THREAD || threadIdx.x == 0) &&
(!Y_THREAD || threadIdx.y == 0) &&
(!Z_THREAD || threadIdx.z == 0);

const auto shared_offset = offset_of_source<X_THREAD, Y_THREAD, Z_THREAD>(blockDim, threadIdx);

if (has_valid_data)
shared_mem[shared_offset] = inp_val;

__syncthreads();

out = shared_mem[shared_offset];
}

} // namespace broadcast
)";

} // namespace cuda
} // namespace fuser
} // namespace jit
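
For orientation, here is a minimal usage sketch of the blockBroadcast helper above. This is an illustration only, not code generated by this PR; the kernel name, tensor names (T0, T1), and launch shape are hypothetical. Each block reads one value in the thread with threadIdx.x == 0 and broadcasts it across the block's x dimension:

__global__ void example_kernel(const float* T0, float* T1, int n) {
  float src = 0.0f;
  if (threadIdx.x == 0) {
    // Only the source thread holds valid data before the broadcast.
    src = T0[blockIdx.x];
  }
  float bcast;
  // X_THREAD = true: broadcast from threadIdx.x == 0; Y_THREAD and Z_THREAD
  // are false, so distinct (y, z) coordinates keep their own values.
  broadcast::blockBroadcast<true, false, false>(bcast, src);
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T1[i] = bcast;
  }
}
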
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -20,7 +20,7 @@ std::vector<Expr*> GPULower::getLoweredExprs() {
// Validate and make some minor modifications in preparation to generate code.
PrepareForLowering(fusion_);

auto preds = ThreadPredicates::compute(fusion_);
ThreadPredicateMap preds(fusion_);

// Run our passes keeping the lowered expressions and forwarding them.
auto loop_nests = LoopNestGenerator::getLoopNest(
12 changes: 6 additions & 6 deletions torch/csrc/jit/codegen/cuda/lower_loops.h
@@ -4,6 +4,8 @@
#include <torch/csrc/jit/codegen/cuda/dispatch.h>

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

namespace torch {
namespace jit {
namespace fuser {
@@ -40,7 +42,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {

// Predicates from ThreadPredicates that we will extend to reduction buffer
// initialization
std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
ThreadPredicateMap& thread_predicates_;
Collaborator:
const ?

Collaborator (Author):
It can't be const because it's modified in LoopNestGenerator::initReduction.


// Create, place, and return the allocation for tv
Expr* pushAlloc(TensorView*);
@@ -71,16 +73,14 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
// Run the pass and accumulate output in lowered_exprs
void generate(const std::vector<Expr*>& exprs);

LoopNestGenerator(
Fusion* _fusion,
std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
LoopNestGenerator(Fusion* _fusion, ThreadPredicateMap& _thread_predicates)
: fusion_(_fusion), thread_predicates_(_thread_predicates) {}

public:
static std::vector<Expr*> getLoopNest(
Fusion* fusion,
std::vector<Expr*> exprs,
std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
ThreadPredicateMap& thread_predicates) {
FusionGuard fg(fusion);
LoopNestGenerator lng(fusion, thread_predicates);
lng.generate(exprs);
@@ -90,4 +90,4 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {

} // namespace fuser
} // namespace jit
} // namespace torch
} // namespace torch