Support broadcasts of predicated tensors within thread blocks #100
In the `IRPrinter` implementation file, two includes are added for the new thread-predicate machinery (the `+` markers are reconstructed from the hunk header, which indicates two added lines):

```diff
@@ -1,6 +1,8 @@
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

 #include <iostream>
```
`IRPrinter::handle(const BroadcastOp*)` is rewritten: unlowered broadcasts print as before, serial lowered broadcasts print as a plain assignment, broadcasts spanning thread dimensions emit a `broadcast::blockBroadcast` call, and grid (block-dimension) broadcasts trip an assert since they are unsupported:

```diff
@@ -478,15 +480,50 @@ void IRPrinter::handle(const ReductionOp* rop) {
 }

 void IRPrinter::handle(const BroadcastOp* bop) {
-  indent();
-  handle(bop->out());
-  os << "\n";
-  indent_size++;
-  indent();
-  os << " = ";
-  handle(bop->in());
-  indent_size--;
-  os << ";\n";
+  // Check if we've lowered yet.
+  bool lowered = bop->out()->getValType() == ValType::TensorIndex;
+  if (!lowered) {
+    os << bop->out() << " = broadcast( " << bop->in() << " )\n";
+    return;
+  }
+
+  const ir_utils::ParallelTypeBitmap domains =
+      ir_utils::getParallelBroadcastDomains(bop, getThreadPredicateMap());
+  const bool thread_x = domains.get(ParallelType::TIDx);
+  const bool thread_y = domains.get(ParallelType::TIDy);
+  const bool thread_z = domains.get(ParallelType::TIDz);
+  const bool block_x = domains.get(ParallelType::BIDx);
+  const bool block_y = domains.get(ParallelType::BIDy);
+  const bool block_z = domains.get(ParallelType::BIDz);
+
+  const bool grid_broadcast_needed = block_x || block_y || block_z;
+  const bool block_broadcast_needed = thread_x || thread_y || thread_z;
+
+  TORCH_INTERNAL_ASSERT(
+      !grid_broadcast_needed, "Parallel broadcast across blocks not supported");
+
+  if (block_broadcast_needed) {
+    indent();
+    os << "broadcast::blockBroadcast<";
+    os << (thread_x ? "true" : "false") << ", ";
+    os << (thread_y ? "true" : "false") << ", ";
+    os << (thread_z ? "true" : "false");
+    os << ">(";
+    handle(bop->out());
+    os << ", ";
+    handle(bop->in());
+    os << ");\n";
+  } else {
+    indent();
+    handle(bop->out());
+    os << "\n";
+    indent_size++;
+    indent();
+    os << " = ";
+    handle(bop->in());
+    indent_size--;
+    os << ";\n";
+  }
 }

 void IRPrinter::handle(const ForLoop* fl) {
```

Review discussion on the `TORCH_INTERNAL_ASSERT` (both messages are truncated in the source):

Reviewer: Do we also have this check in …?

Author: Yes. It's checked in …
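The emitted call targets a `broadcast::blockBroadcast` helper in the CUDA runtime headers, which this diff doesn't show. As a rough illustration of what such a helper does, here is a minimal sketch of a shared-memory block broadcast; the three-parameter signature, the explicit `shared_mem` argument, and the indexing scheme are assumptions for illustration, not the project's actual runtime code.

```cuda
namespace broadcast {

// Hypothetical sketch: X_THREAD/Y_THREAD/Z_THREAD mirror the <true/false, ...>
// template arguments printed above and mark the thread dimensions the
// broadcast covers. shared_mem is assumed to hold one T per broadcast group.
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
__device__ void blockBroadcast(T& out, const T& in, T* shared_mem) {
  // Only threads at index 0 of every covered dimension hold a valid input.
  const bool has_valid_input = (!X_THREAD || threadIdx.x == 0) &&
      (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0);

  // Flatten the non-covered thread indices so that all threads in the same
  // broadcast group share one shared-memory slot.
  const unsigned int slot = (X_THREAD ? 0 : threadIdx.x) +
      (Y_THREAD ? 0 : threadIdx.y) * (X_THREAD ? 1u : blockDim.x) +
      (Z_THREAD ? 0 : threadIdx.z) * (X_THREAD ? 1u : blockDim.x) *
          (Y_THREAD ? 1u : blockDim.y);

  if (has_valid_input) {
    shared_mem[slot] = in;
  }
  __syncthreads();
  out = shared_mem[slot];
  __syncthreads(); // protect shared_mem before any later reuse
}

} // namespace broadcast
```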
Later in the same file, a lazily-initialized accessor for the thread-predicate map is added after `printKernel`:

```diff
@@ -640,6 +677,14 @@ void IRPrinter::printKernel(
   os << "}\n";
 }

+const ThreadPredicateMap& IRPrinter::getThreadPredicateMap() {
+  if (thread_predicates_ == nullptr) {
+    Fusion* fusion = FusionGuard::getCurFusion();
+    thread_predicates_ = std::make_unique<ThreadPredicateMap>(fusion);
+  }
+  return *thread_predicates_;
+}
+
 std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
   IRPrinter p(os);
   p.handle(stmt);
```
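Note that the getter builds the `ThreadPredicateMap` only on first use, from whatever fusion is current at print time; the unlowered-IR path in `handle(const BroadcastOp*)` returns before ever calling it, so printing plain IR never pays for the analysis. In isolation the idiom looks like the sketch below (placeholder types, not the PR's code):

```cpp
#include <memory>

// Placeholder stand-in for an expensive analysis like ThreadPredicateMap.
struct Analysis {
  explicit Analysis(int input) : result(input * 2) {}
  int result;
};

class LazyHolder {
 public:
  // The analysis is constructed the first time someone asks for it,
  // then cached for the lifetime of the holder.
  const Analysis& get(int input) {
    if (analysis_ == nullptr) {
      analysis_ = std::make_unique<Analysis>(input);
    }
    return *analysis_;
  }

 private:
  std::unique_ptr<Analysis> analysis_;
};
```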
The `IRPrinter` header picks up the corresponding include:

```diff
@@ -3,6 +3,7 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>

 #include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

 #include <iostream>
```
and declares the new private state and accessor:

```diff
@@ -126,6 +127,11 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
   void printKernel(
       const std::vector<Expr*>& exprs,
       const std::string& kernel_name);

+ private:
+  std::unique_ptr<ThreadPredicateMap> thread_predicates_;
+
+  const ThreadPredicateMap& getThreadPredicateMap();
+
 };

 TORCH_CUDA_API std::ostream& operator<<(
```

Review discussion on the new `private:` section:

Reviewer: Please separate private methods from private data into separate private sections, e.g. … The rationale is to be able to locate the state easily, instead of having it mixed with methods.

Author: I'm not sure I agree here. It just has one field and one method. To make it a little easier to separate, I added an empty line between them.
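The reviewer's example was elided from the page; the suggested layout presumably looks something like the sketch below (hypothetical, shown only to illustrate the two-section convention being requested):

```cpp
class IRPrinter : public OptInConstDispatch {
  // ... public interface ...

 private:
  // Private methods.
  const ThreadPredicateMap& getThreadPredicateMap();

 private:
  // Private data.
  std::unique_ptr<ThreadPredicateMap> thread_predicates_;
};
```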
The loop-nest generator header gains two includes and switches its predicate storage from a raw map to the new `ThreadPredicateMap` type:

```diff
@@ -4,6 +4,8 @@
 #include <torch/csrc/jit/codegen/cuda/dispatch.h>

+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

 namespace torch {
 namespace jit {
 namespace fuser {
```

```diff
@@ -40,7 +42,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {

   // Predicates from ThreadPredicates that we will extend to reduction buffer
   // initialization
-  std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
+  ThreadPredicateMap& thread_predicates_;

   // Create, place, and return the allocation for tv
   Expr* pushAlloc(TensorView*);
```

Review discussion on the `thread_predicates_` member (the reply is truncated in the source):

Reviewer: `const`?

Author: It can't have `const` as it can be modified in …
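To spell out the author's point: if any member function writes through the reference, declaring it `const ThreadPredicateMap&` would make that call ill-formed. A reduced, hypothetical example (not the actual `LoopNestGenerator` code):

```cpp
#include <set>

// Hypothetical stand-in for ThreadPredicateMap.
struct PredicateMapLike {
  void insert(int key) { keys_.insert(key); } // non-const mutator
  std::set<int> keys_;
};

struct GeneratorLike {
  // Declaring this as const PredicateMapLike& would reject the call below
  // at compile time, which is why the reference must stay mutable.
  PredicateMapLike& preds_;

  void extendForReductionInit() {
    preds_.insert(42); // writes through the reference
  }
};
```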
The constructor and the static entry point are updated to the same type:

```diff
@@ -71,16 +73,14 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
   // Run the pass and accumulate output in lowered_exprs
   void generate(const std::vector<Expr*>& exprs);

-  LoopNestGenerator(
-      Fusion* _fusion,
-      std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
+  LoopNestGenerator(Fusion* _fusion, ThreadPredicateMap& _thread_predicates)
       : fusion_(_fusion), thread_predicates_(_thread_predicates) {}

 public:
   static std::vector<Expr*> getLoopNest(
       Fusion* fusion,
       std::vector<Expr*> exprs,
-      std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
+      ThreadPredicateMap& thread_predicates) {
     FusionGuard fg(fusion);
     LoopNestGenerator lng(fusion, thread_predicates);
     lng.generate(exprs);
```

A final whitespace-only hunk (apparently a trailing-newline fix; the two sides are textually identical) closes the file:

```diff
@@ -90,4 +90,4 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {

 } // namespace fuser
 } // namespace jit
-} // namespace torch
+} // namespace torch
```
Review discussion on `IRPrinter::handle(const BroadcastOp*)`:

Reviewer: Shouldn't we follow ReductionOp more closely and first check whether it's been lowered to figure out how we're going to print it? It looks to me like you're going to print `broadcast::blockBroadcast<...` even if it hasn't been lowered.

Author: Thanks. That's right. Fixed.
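With that fix in place, the handler produces three distinct forms depending on lowering state and parallelization. The tensor names below are made up, but each shape follows directly from the streaming code in the hunk above:

```cpp
// Before lowering (output is not yet a TensorIndex):
//   T2 = broadcast( T1 )
//
// After lowering, broadcast across TIDx only:
//   broadcast::blockBroadcast<true, false, false>(T2[...], T1[...]);
//
// After lowering, no parallel thread dimension involved:
//   T2[...]
//     = T1[...];
```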