Skip to content

Commit

Permalink
Allow IrBuilder to have Int and Int32 inputs in binary expressions (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Nov 30, 2022
1 parent 95a28c3 commit ed2c040
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 28 deletions.
41 changes: 24 additions & 17 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,11 @@ c10::optional<IterDomain*> getMaybeIndexedIdToHoist(
indexing.indexMap().find(contig_id_it->second) !=
indexing.indexMap().end(),
"Invalid contig index: ",
contig_id_it->second->toString());
contig_id_it->second->toString(),
", root ID: ",
root_id->toString(),
", TV: ",
tv->toString());

return indexed_id;
}
Expand Down Expand Up @@ -1406,7 +1410,15 @@ Val* hoistProducerIndex(
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*> initial_loop_index_map,
const std::vector<kir::ForLoop*>& loops,
Val* index) {
Val* index,
bool is_overriden_index) {
if (is_overriden_index) {
// do not hoist overridden index. It is used by
// select/index_select, so IterDomain equivalence does not mean
// the same index math
return index;
}

auto maybe_indexed_producer_id = getMaybeIndexedIdToHoist(
producer_root_id, producer_tv, producer_indexing, index);

Expand Down Expand Up @@ -1573,7 +1585,8 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(

Val* root_ind = nullptr;
auto override_it = override_index.find(root_dom[i]);
if (override_it != override_index.end()) {
const bool is_overriden = override_it != override_index.end();
if (is_overriden) {
root_ind = override_it->second;
} else if (
producer_indexing.indexMap().find(root_dom[i]) !=
Expand Down Expand Up @@ -1602,14 +1615,11 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
producer_indexing_from_idgraph.resolved_loop_domains,
producer_indexing_from_idgraph.initial_concrete_index_map,
loops,
root_ind);
root_ind,
is_overriden);

root_ind = getProducerIndexWithHalo(
producer_tv,
i,
root_ind,
consumer_tv,
override_index.count(root_dom[i]));
producer_tv, i, root_ind, consumer_tv, is_overriden);

root_ind = getProducerIndexWithGather(
root_ind,
Expand Down Expand Up @@ -1822,9 +1832,9 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
root_dom[i]->toString());

auto override_it = override_index.find(root_dom[i]);
const bool is_overriden = override_it != override_index.end();
auto root_ind_i =
(override_it != override_index.end() ? override_it->second
: index_map.at(root_dom[i]));
is_overriden ? override_it->second : index_map.at(root_dom[i]);

// index hoist must be done before the adjustments for halo
root_ind_i = hoistProducerIndex(
Expand All @@ -1836,14 +1846,11 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
producer_indexing_from_idgraph.resolved_loop_domains,
producer_indexing_from_idgraph.initial_concrete_index_map,
loops,
root_ind_i);
root_ind_i,
is_overriden);

root_ind_i = getProducerIndexWithHalo(
producer_tv,
i,
root_ind_i,
consumer_tv,
override_index.count(root_dom[i]));
producer_tv, i, root_ind_i, consumer_tv, is_overriden);

root_ind_i = getProducerIndexWithGather(
root_ind_i,
Expand Down
35 changes: 28 additions & 7 deletions torch/csrc/jit/codegen/cuda/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,34 @@ Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
TORCH_CHECK(
lhs != nullptr && rhs != nullptr,
"Either lhs or rhs is a nullptr in newArithmeticExpr.");
TORCH_CHECK(
lhs->dtype() == rhs->dtype(),
"Incompatible operand types: ",
lhs->dtype(),
" and ",
rhs->dtype());
auto result = newScalar(lhs->dtype());

auto dtype = lhs->dtype();

// In principle, we should keep these IrBuilder functions as
// simple as possible since they are just used by the lowering for
// scalar computations. We should enforce strict typing with no
// implicit type promotion unless required. However, for
// int and int64_t, our usages are pretty loose in many places. Originally we
// only had int64_t, then we added nvfuser_index_t and replaced the types of
// some of the values from int64_t to int just at the beginning of lowering.
// This resulted in inconsistent usages of integer types in many places, and
// fixing all of them to make everything consistent would be a lot of work
// than just allowing the integer type promotion for the two inputs as below.
// Note that this is only needed for integer types. See also PR #2228.
if (lhs->dtype() != rhs->dtype()) {
if ((lhs->dtype() == DataType::Int && rhs->dtype() == DataType::Int32) ||
(lhs->dtype() == DataType::Int32 && rhs->dtype() == DataType::Int)) {
dtype = DataType::Int;
} else {
TORCH_CHECK(
false,
"Incompatible operand types: ",
lhs->dtype(),
" and ",
rhs->dtype());
}
}
auto result = newScalar(dtype);
IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return result;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ TensorIndex::TensorIndex(
std::all_of(
indices.begin(),
indices.end(),
[](Val* v) { return v->dtype() == DataType::Int; }),
[](Val* v) { return isIntegralType(v->dtype()); }),
"Cannot index with a value other than an int.");
indices_.erase(
std::remove_if(
Expand Down
18 changes: 15 additions & 3 deletions torch/csrc/jit/codegen/cuda/lower_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,27 @@ void IndexLowering::handle(const TernaryOp* top) {
}

void IndexLowering::handle(const IndexSelectOp* sop) {
const auto indices = lowerSrcIndex(sop->input(1), sop->output(0));
auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0));
auto lowered_index_cast = lowered_index;

// If the type of the index tensor is different from the kernel
// index type, promote it to the kernel index type
if (GpuLower::current()->kernel()->indexType() !=
sop->input(1)->getDataType().value()) {
lowered_index_cast =
IrBuilder::newScalar(GpuLower::current()->kernel()->indexType());
IrBuilder::create<UnaryOp>(
UnaryOpType::Cast, lowered_index_cast, lowered_index);
}

const std::unordered_map<IterDomain*, Val*> override_index = {
{sop->getSelectAxis(), indices}};
{sop->getSelectAxis(), lowered_index_cast}};
const auto lookup =
lowerSrcIndex(sop->input(0), sop->output(0), override_index);

const auto out = lowerDstIndex(sop->output(0));
pushBack(IrBuilder::create<IndexSelectOp>(
out, lookup, sop->dim(), sop->getSelectAxis(), indices));
out, lookup, sop->dim(), sop->getSelectAxis(), lowered_index));
GpuLower::current()->propagateExprInfo(sop, back());
}

Expand Down

0 comments on commit ed2c040

Please sign in to comment.