Skip to content

Commit

Permalink
Indexing refactor stage 1: remove reference tensor creation in all te…
Browse files Browse the repository at this point in the history
…nsor indexing logic (#1690)

Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>
  • Loading branch information
3 people committed Jun 25, 2022
1 parent c8b4f42 commit 97d3b84
Show file tree
Hide file tree
Showing 25 changed files with 2,013 additions and 426 deletions.
102 changes: 101 additions & 1 deletion compute_at_map.cpp
Expand Up @@ -226,7 +226,8 @@ void IterDomainGraph::initializeId(
}
}

ComputeAtMap::ComputeAtMap(Fusion* fusion) : id_graph_(fusion) {
ComputeAtMap::ComputeAtMap(Fusion* fusion)
: id_graph_(fusion), fusion_(fusion) {
build(fusion);
}

Expand Down Expand Up @@ -257,6 +258,105 @@ void ComputeAtMap::validateAndPropagatePType() {
}
}

void ComputeAtMap::allocateIndexVariables() {
// Run through all disjoint sets registered in loop map,
// all lowered kir::ForLoop will correspond to one of the disjoint sets
// and we only need one index variable for each set.
for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) {
ParallelType ptype;
// first allocate thread and grid parallel indices:
// The validation pass will check that the parallel bindings within the
// loop nodes are consistent so all the loops within this disjoint set
// will be realized implicitly using parallel index variables.
if (std::any_of(
loop_disjoint_set->vector().begin(),
loop_disjoint_set->vector().end(),
[&ptype](IterDomain* id) {
if (id->isThread() &&
// Halo extended parallel loops currently are handled
// differently and an index variable would still
// be allocated in this case.
(GpuLower::current()->haloInfo().getExtent(id) == nullptr)) {
ptype = id->getParallelType();
return true;
}
return false;
})) {
loop_index_variable_map_[loop_disjoint_set.get()] =
NamedScalar::getParallelIndex(ptype);
continue;
}

// All loops in this set are non-parallel, non-concretized broadcast
// iterdomains, their "index variable" should be zero.
if (std::all_of(
loop_disjoint_set->vector().begin(),
loop_disjoint_set->vector().end(),
[](IterDomain* id) { return id->isBroadcast(); })) {
loop_index_variable_map_[loop_disjoint_set.get()] = fusion_->zeroVal();
continue;
}

// Allocate variable for the iterdomains:
auto concrete_loop_id_it = concrete_id_cache_.find(loop_disjoint_set);
TORCH_INTERNAL_ASSERT(
concrete_loop_id_it != concrete_id_cache_.end(),
"Concrete id not computed");

auto concrete_loop_id = concrete_loop_id_it->second;

// Need to allocate double buffered loop differently.
if (GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(
concrete_loop_id)) {
// Allocate index variable for each stage of the double buffered loop.
double_buffered_loop_index_variable_map_[loop_disjoint_set.get()] =
std::make_unique<DoubleBufferIndices>(DoubleBufferIndices(
{{DoubleBufferLoopStage::Prolog,
IrBuilder::create<Int>(c10::nullopt)},
{DoubleBufferLoopStage::Main,
IrBuilder::create<Int>(c10::nullopt)},
{DoubleBufferLoopStage::Epilog,
IrBuilder::create<Int>(c10::nullopt)}}));
} else {
// Everything now should be serial concrete loops,
// we just allocate a loop index integer for each set of loops.
loop_index_variable_map_[loop_disjoint_set.get()] =
IrBuilder::create<Int>(c10::nullopt);
}
}
}

Val* ComputeAtMap::getIndexVariable(
IterDomain* id,
DoubleBufferLoopStage double_buffer_loop_stage) const {
TORCH_INTERNAL_ASSERT(
id_graph_.loopNodes().mappingExists(id),
"Index Variable: no index variable allocated as ",
id->toString(),
" is not registered in loop map");
const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id));

// Check if this loop was modified by double buffer pass.
bool is_double_buffer_iterdomain =
GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(id);

if (is_double_buffer_iterdomain) {
// Use dedicated double buffer index variable if the loop is double buffer
// loop
if (double_buffer_loop_stage == DoubleBufferLoopStage::NotApplicable) {
// The double buffered loop stages are created after the loop nest
// lowering phase so this function will be querried before the double
// buffer pass. At that point, no forloop has any double buffer
// stage defined, and we just default to using the main stage index.
double_buffer_loop_stage = DoubleBufferLoopStage::Main;
}
return double_buffered_loop_index_variable_map_.at(loop_set)->at(
double_buffer_loop_stage);
} else {
return loop_index_variable_map_.at(loop_set);
}
}

bool ComputeAtMap::areMapped(
IterDomain* id0,
IterDomain* id1,
Expand Down
49 changes: 49 additions & 0 deletions compute_at_map.h
Expand Up @@ -112,6 +112,8 @@ class TORCH_CUDA_CU_API IterDomainGraph {

class TrivialReductionInfo;

using DoubleBufferIndices = std::unordered_map<DoubleBufferLoopStage, Int*>;

class TORCH_CUDA_CU_API ComputeAtMap {
public:
ComputeAtMap() = delete;
Expand All @@ -122,6 +124,25 @@ class TORCH_CUDA_CU_API ComputeAtMap {
//! all IterDomains in the disjoint set to that PType.
void validateAndPropagatePType();

//! Run through disjoint sets in the LOOP map and allocate the index
//! variable for the associated for loop that will be generated
//! for each disjoint sets in the loop map. This pre-allocation makes
//! 2 key assumptions about computeAt map that would very likely be
//! long term invariant:
//! 1. All kir::forloop created in the lowering pass should belong
//! to one of the disjoint sets in loop map.
//! 2. The lowering pass will *never* create a loop nest with 2
//! different nesting levels mapped together, i.e. the case below
//! never occurs:
//! for i in IterDomain1
//! for j in IterDomain2
//! ...
//! With loop_map.areMapped(IterDomain1, IterDomain2) == true.
//! Under this condition, we can pre-allocate all required index
//! variable integers before creating any kir::forloop, and this
//! would help optimizing the generated integer math for indexing.
void allocateIndexVariables();

//! Returns if id0 and id1 are mapped to eachother with provided IdMappingMode
bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const;

Expand Down Expand Up @@ -151,6 +172,16 @@ class TORCH_CUDA_CU_API ComputeAtMap {
//! Get the ID sets for a provided IdMappingMode
const DisjointSets<IterDomain*>& getIdSets(IdMappingMode mode) const;

//! Returns the pre-allocated index variable integer used in
//! the kir::ForLoop corresponding to the given IterDomain.
//! this interface is only valid if the ID has a loop mapping,
//! ca_map will throw exceptions if given iterdomain doesn't
//! have a loop map entry.
Val* getIndexVariable(
IterDomain* id,
DoubleBufferLoopStage double_buffer_loop_stage =
DoubleBufferLoopStage::NotApplicable) const;

private:
// Build id_graph_
void build(Fusion* fusion);
Expand Down Expand Up @@ -178,6 +209,24 @@ class TORCH_CUDA_CU_API ComputeAtMap {
std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
IterDomain*>
concrete_id_cache_;

//! Allocated Loop index variable through the CA map.
//! only valid for disjoint sets on the loop ca map.
std::unordered_map<const VectorOfUniqueEntries<IterDomain*>*, Val*>
loop_index_variable_map_;

//! Allocated loop indices for double buffer loop.
//! only valid for disjoint sets on the loop ca map
//! that have double buffer-ed iterdomains.
using DoubleBufferIndicesPtr = std::unique_ptr<DoubleBufferIndices>;
std::unordered_map<
const VectorOfUniqueEntries<IterDomain*>*,
DoubleBufferIndicesPtr>
double_buffered_loop_index_variable_map_;

// Shortcut to access the fusion this computeAt map was
// built from.
Fusion* fusion_;
};

} // namespace cuda
Expand Down

0 comments on commit 97d3b84

Please sign in to comment.