Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,17 @@ void IndexCompute::handle(Split* split) {
} else {
index_map_[in_id] = ir_builder.addExpr(
ir_builder.mulExpr(outer_ind, getExtent(inner_id)), inner_ind);

// The extent of a root axis should be only updated when its
// allocation is partial, i.e., zero_merged_in is true. See issue
// #1016 and the FusionIssue1016 test.
if (split->in()->definition() != nullptr) {
auto def = split->in()->definition();
if (def->isA<Split>() && def->as<Split>()->viewSplit()) {
return;
}
}

if (split->in()->definition() != nullptr || zero_merged_in) {
extent_map_[in_id] =
ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id));
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/index_reference_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void IndexReferenceReplay::handle(Split* split) {
ref_in,
split->factor(),
split->innerSplit(),
split->viewSplit(),
split->startOffset(),
split->stopOffset());

Expand Down
6 changes: 4 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ class TORCH_CUDA_CU_API TensorView : public Val {
int axis,
unsigned int factor,
bool inner_split = true,
bool trim_out_of_bounds = false);
bool trim_out_of_bounds = false,
bool view_split = false);

// Split "axis" into 2 axes where the inner axes is size of "factor"
// and outer axis is size axis.size() / factor. Factor can be a symbolic
Expand All @@ -273,7 +274,8 @@ class TORCH_CUDA_CU_API TensorView : public Val {
int axis,
Val* factor,
bool inner_split = true,
bool trim_out_of_bounds = false);
bool trim_out_of_bounds = false,
bool view_split = false);

// Merge axis_o and axis_i into 1 IterDomain
TensorView* merge(int axis_o, int axis_i);
Expand Down
54 changes: 33 additions & 21 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,29 @@ class TORCH_CUDA_CU_API IterDomain : public Val {

static IterDomain* merge(IterDomain* outer, IterDomain* inner);

//! start_offset and stop_offset defines partial split. Only root
//! domains are allowed to have non-zero start and stop offsets.
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
bool view_split,
Val* start_offset = nullptr,
Val* stop_offset = nullptr);

//! trim_out_of_bounds controls how the values outside start and stop
//! positions are treated. The option is only valid with root
//! domains as non-root domains do not have valid start and stop
//! positions.
//!
//! \param trim_out_of_bounds Trims [0, start_] and [-stop_offset_, extent_]
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
bool view_split,
bool trim_out_of_bounds);

bool isReduction() const {
return getIterType() == IterType::Reduction;
}
Expand Down Expand Up @@ -583,27 +606,6 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
friend ReplayTransformations;
friend IndexReferenceReplay;

//! start_offset and stop_offset defines partial split. Only root
//! domains are allowed to have non-zero start and stop offsets.
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
Val* start_offset = nullptr,
Val* stop_offset = nullptr);

//! trim_out_of_bounds controls how the values outside start and stop
//! positions are treated. The option is only valid with root
//! domains as non-root domains do not have valid start and stop
//! positions.
//!
//! \param trim_out_of_bounds Trims [0, start_] and [-stop_offset_, extent_]
static std::pair<IterDomain*, IterDomain*> split(
IterDomain* in,
Val* factor,
bool inner_split,
bool trim_out_of_bounds);

private:
//! Valid range is defined as [start:-stop_offset]
Val* const start_ = nullptr;
Expand Down Expand Up @@ -741,6 +743,7 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
int axis_,
Val* factor,
bool inner_split,
bool view_split,
bool trim_out_of_bounds = false);

// Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting
Expand Down Expand Up @@ -790,6 +793,7 @@ class TORCH_CUDA_CU_API Split : public Expr {
IterDomain* in,
Val* factor,
bool inner_split = true,
bool view_split = false,
Val* start_offset = nullptr,
Val* stop_offset = nullptr);

Expand All @@ -798,12 +802,15 @@ class TORCH_CUDA_CU_API Split : public Expr {
  //! The outer output IterDomain produced by this split.
  IterDomain* outer() const {
    return outer_;
  }

  //! The inner output IterDomain produced by this split.
  IterDomain* inner() const {
    return inner_;
  }

  //! The input IterDomain that this split was applied to.
  IterDomain* in() const {
    return in_;
  }

  //! The split factor value.
  Val* factor() const {
    return factor_;
  }
Expand All @@ -812,6 +819,10 @@ class TORCH_CUDA_CU_API Split : public Expr {
return inner_split_;
}

  //! True when this split was constructed with the view_split flag set
  //! (presumably a split originating from a view transform — the flag is
  //! simply stored at construction; see the Split constructor).
  bool viewSplit() const {
    return view_split_;
  }

Val* startOffset() const {
TORCH_INTERNAL_ASSERT(start_offset_ != nullptr);
return start_offset_;
Expand All @@ -833,6 +844,7 @@ class TORCH_CUDA_CU_API Split : public Expr {
IterDomain* const in_ = nullptr;
Val* const factor_ = nullptr;
bool inner_split_ = true;
bool view_split_ = false;
//! Start position of the input domain. Non-zero means partial
//! split. Elements until this offset are ignored.
Val* const start_offset_ = nullptr;
Expand Down
16 changes: 12 additions & 4 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
Val* factor,
bool inner_split,
bool view_split,
Val* start_offset,
Val* stop_offset) {
TORCH_CHECK(
Expand Down Expand Up @@ -836,18 +837,21 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
in->getIterType(),
in->isRFactorProduct());

new Split(ido, idi, in, factor, inner_split, start_offset, stop_offset);
new Split(
ido, idi, in, factor, inner_split, view_split, start_offset, stop_offset);
return {ido, idi};
}

//! Convenience overload: when trim_out_of_bounds is set, restrict the
//! split to the valid [start, -stop_offset] range of the input domain by
//! forwarding its own start/stop offsets; otherwise forward nullptr so the
//! primary overload uses its default (zero) offsets.
std::pair<IterDomain*, IterDomain*> IterDomain::split(
    IterDomain* in,
    Val* factor,
    bool inner_split,
    bool view_split,
    bool trim_out_of_bounds) {
  Val* start = nullptr;
  Val* stop = nullptr;
  if (trim_out_of_bounds) {
    start = in->start();
    stop = in->stopOffset();
  }
  return IterDomain::split(in, factor, inner_split, view_split, start, stop);
}

// TODO: We should change parallelize interface to be on tensorview or at least
Expand Down Expand Up @@ -1159,6 +1163,7 @@ void TensorDomain::split(
int axis_,
Val* factor,
bool inner_split,
bool view_split,
bool trim_out_of_bounds) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
if (axis_ < 0)
Expand All @@ -1178,8 +1183,8 @@ void TensorDomain::split(
"Partial split is only allowed with root domains");
}

auto split_ids =
IterDomain::split(id, factor, inner_split, trim_out_of_bounds);
auto split_ids = IterDomain::split(
id, factor, inner_split, view_split, trim_out_of_bounds);
domain_.erase(domain_.begin() + axis_);
domain_.insert(domain_.begin() + axis_, split_ids.second);
domain_.insert(domain_.begin() + axis_, split_ids.first);
Expand Down Expand Up @@ -1362,6 +1367,7 @@ Split::Split(
IterDomain* in,
Val* factor,
bool inner_split,
bool view_split,
Val* start_offset,
Val* stop_offset)
: Expr(ExprType::Split),
Expand All @@ -1370,6 +1376,7 @@ Split::Split(
in_{in},
factor_{factor},
inner_split_{inner_split},
view_split_{view_split},
start_offset_{start_offset != nullptr ? start_offset : new Int(0)},
stop_offset_{stop_offset != nullptr ? stop_offset : new Int(0)} {
TORCH_INTERNAL_ASSERT(
Expand All @@ -1390,6 +1397,7 @@ Split::Split(const Split* src, IrCloner* ir_cloner)
in_(ir_cloner->clone(src->in_)),
factor_(ir_cloner->clone(src->factor_)),
inner_split_(src->inner_split_),
view_split_(src->view_split_),
start_offset_(ir_cloner->clone(src->start_offset_)),
stop_offset_(ir_cloner->clone(src->stop_offset_)) {}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/mutator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Statement* OptOutMutator::mutate(Split* s) {
return s;
}
FusionGuard::getCurFusion()->removeExpr(s);
return new Split(ot, inr, in, fact, s->innerSplit());
return new Split(ot, inr, in, fact, s->innerSplit(), s->viewSplit());
}

Statement* OptOutMutator::mutate(Merge* m) {
Expand Down
10 changes: 6 additions & 4 deletions torch/csrc/jit/codegen/cuda/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ TensorView* TensorView::split(
int axis_,
Val* factor,
bool inner_split,
bool trim_out_of_bounds) {
bool trim_out_of_bounds,
bool view_split) {
// Only check things associated with axis, factor will be validated in
// IterDomain
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");
Expand Down Expand Up @@ -281,16 +282,17 @@ TensorView* TensorView::split(
"Splitting an axis of non-Serial parallel type is not supported at this time."
" Parallelization strategy must be set after calling split.");

domain()->split(axis_, factor, inner_split, trim_out_of_bounds);
domain()->split(axis_, factor, inner_split, view_split, trim_out_of_bounds);
return this;
}

//! Convenience overload taking a compile-time constant factor. Wraps the
//! factor in an IR Int and delegates to the Val*-based overload, which
//! performs all axis/factor validation.
TensorView* TensorView::split(
    int axis,
    unsigned int factor,
    bool inner_split,
    bool trim_out_of_bounds,
    bool view_split) {
  Val* factor_val = new Int(factor);
  split(axis, factor_val, inner_split, trim_out_of_bounds, view_split);
  return this;
}

Expand Down
7 changes: 6 additions & 1 deletion torch/csrc/jit/codegen/cuda/transform_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ void ReplayTransformations::handle(Split* s) {

// Replay the split onto mapped
auto outs = IterDomain::split(
mapped, s->factor(), s->innerSplit(), s->startOffset(), s->stopOffset());
mapped,
s->factor(),
s->innerSplit(),
s->viewSplit(),
s->startOffset(),
s->stopOffset());
// Remove mapped from the leaf IDs
leaf_ids_.erase(mapped);

Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ReplaySelf : public ReplayTransformations {
mapped,
s->factor(),
s->innerSplit(),
s->viewSplit(),
s->startOffset(),
s->stopOffset());

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ReplayRFactor : public ReplayTransformations {
true);

// Generate the split node
new Split(ido, idi, mapped, s->factor(), s->innerSplit());
new Split(ido, idi, mapped, s->factor(), s->innerSplit(), s->viewSplit());

// Remove mapped id from leaf IDs
leaf_ids_.erase(mapped);
Expand Down