Collaborator

@naoyam naoyam commented Nov 16, 2021

This is a second attempt to fix an indexing problem that was first reported in PR #1243. See that PR as well for the context of this problem.

Also, the repro of issue #1074 is now working (FusionNonDivisibleSplit2).

This PR addresses an index linearization problem that was exposed by view and strided gather operations, both of which are based on rFactor. The root cause of the problem is non-divisible splits. When indexing with such splits, the domain that is split gets a larger extent due to the ceilDiv of the outer domain. This is propagated back from leaf to parent domains in IndexCompute::handle(Split*). It turns out that there are consistency problems in deciding when to propagate. The most sensible way is likely to completely stop doing so, but that also means intermediate domains have to be predicated if they are split by a non-divisible factor.
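
To make the extent problem concrete, here is a minimal standalone sketch in plain C++. It is not nvfuser code, and `ceilDiv`, `paddedExtent`, and `countInBounds` are illustrative names only. Splitting an extent of 10 by a factor of 4 gives an outer extent of ceilDiv(10, 4) = 3, so the outer and inner loops together cover 12 index values; the last two have to be masked by a predicate on the domain that was split.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for positive integers.
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Number of index values covered by the outer/inner loop nest after a split.
// This is larger than `extent` whenever the split is non-divisible.
constexpr int64_t paddedExtent(int64_t extent, int64_t factor) {
  return ceilDiv(extent, factor) * factor;
}

// Walk the split loop nest and count the iterations that pass the predicate
// `index < extent`. Without the predicate, paddedExtent() iterations would run.
inline int64_t countInBounds(int64_t extent, int64_t factor) {
  assert(extent > 0 && factor > 0);
  int64_t count = 0;
  for (int64_t outer = 0; outer < ceilDiv(extent, factor); ++outer) {
    for (int64_t inner = 0; inner < factor; ++inner) {
      const int64_t index = outer * factor + inner;
      if (index < extent) { // predicate on the domain that was split
        ++count;
      }
    }
  }
  return count;
}
```

For a divisible split such as extent 8 and factor 4, `paddedExtent` equals the original extent and the predicate never fires.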

Also, predicating such domains doesn't work if vectorization is involved. If an output domain of a split is vectorized, the split must be divisible. Divisibility is validated at kernel launch time.
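
The launch-time check can be pictured roughly as follows. This is a hedged sketch and not the code in this PR: `isDivisibleSplit` and `validateVectorizedSplit` are hypothetical names, and the real validation evaluates symbolic extents with an expression evaluator rather than taking plain integers.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// True if an extent can be split by `factor` with no remainder.
inline bool isDivisibleSplit(int64_t extent, int64_t factor) {
  assert(factor > 0);
  return extent % factor == 0;
}

// Hypothetical launch-time validation for a split whose output is vectorized.
// Throws instead of predicating, since a vectorized access cannot be
// partially masked.
inline void validateVectorizedSplit(int64_t extent, int64_t factor) {
  if (!isDivisibleSplit(extent, factor)) {
    throw std::runtime_error(
        "Non-divisible split with vectorization: extent " +
        std::to_string(extent) + ", factor " + std::to_string(factor));
  }
}
```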

The main logic is in NonDivisibleSplitInfo. It scans a fusion to locate any Split expression that must be either predicated or validated. If an IterDomain is run-time validated, it's not predicated as it's unnecessary (NonDivisibleSplitInfo::removeRedundancy()).
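
As a rough illustration of the redundancy removal, the sketch below drops every split that is already validated at run time from the list of splits to predicate. It is a simplified model under stated assumptions: `SplitId` is a stand-in for `Split*`, `removeValidatedSplits` is a hypothetical helper, and the actual `removeRedundancy()` works on per-tensor lists and may differ in detail.

```cpp
#include <algorithm>
#include <iterator>
#include <unordered_set>
#include <vector>

// Stand-in for a Split expression; the real code would hold Split* pointers.
using SplitId = int;

// Return the splits that still need predication: those in `to_predicate`
// that are not already covered by run-time validation in `to_validate`.
inline std::vector<SplitId> removeValidatedSplits(
    const std::vector<SplitId>& to_predicate,
    const std::unordered_set<SplitId>& to_validate) {
  std::vector<SplitId> result;
  std::copy_if(
      to_predicate.begin(),
      to_predicate.end(),
      std::back_inserter(result),
      [&](SplitId s) { return to_validate.count(s) == 0; });
  return result;
}
```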

getReferenceRootPredicates is extended to generate predicates for those domains that may be split with a non-divisible factor. Currently, it finds root or merged contiguous domains to predicate. I extended that to include non-divisible split domains (getNonDivisibleDomainsToPredicate). While index_compute.cpp may appear to have many changes, most of them don't change the existing logic. For example, the change in getStartAndStopOffsets is to skip the adjustments for shift and gather operations when predicating non-divisible split domains.

All of the C++ tests are passing, but test_native_layer_norm_half is failing due to a runtime validation of a non-divisible domain with vectorization. The domain to split is:

i515 = T0.size[1] * T0.size[2]
~ i518 = T0.size[0] * i515
~ i521 = ceilDiv(i518, 4)
i521

This i521 domain is split by 2 and the inner domain of size 2 is vectorized. The test fails with T0 being [17, 10, 10] since 17 * 10 * 10 / 4 isn't divisible by 2. Without the runtime check, an alignment error occurs. Since @csarofeen mentioned he's working on changing scheduling algorithms for vectorization, I've done nothing on the failure yet.
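
The arithmetic behind the failure can be checked with a few lines of standalone C++. This is only a worked example of the expression above; `vectorizedInputExtent` is an illustrative name, and the factors 4 and 2 are taken from the description.

```cpp
#include <cstdint>

// Ceiling division for positive integers.
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Mirrors i521 = ceilDiv(T0.size[0] * (T0.size[1] * T0.size[2]), 4).
constexpr int64_t vectorizedInputExtent(int64_t s0, int64_t s1, int64_t s2) {
  return ceilDiv(s0 * (s1 * s2), 4);
}

// T0 = [17, 10, 10]: 1700 elements, i521 = 425, which is odd.
static_assert(vectorizedInputExtent(17, 10, 10) == 425, "i521 for [17, 10, 10]");
static_assert(
    vectorizedInputExtent(17, 10, 10) % 2 != 0,
    "the split by 2 is non-divisible, so the run-time check fails");

// For comparison, T0 = [16, 10, 10] gives i521 = 400, which splits evenly by 2.
static_assert(vectorizedInputExtent(16, 10, 10) % 2 == 0, "divisible case");
```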

Closes #1074

&fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__);
}

TEST(NVFuserTest, FusionRfactorPredication2_CUDA) {
Collaborator Author

This was a test created for #1250 but removed as it was failing due to the extent problem.

@naoyam naoyam marked this pull request as ready for review November 16, 2021 18:41
@naoyam naoyam changed the title from "[WIP] Make non-divisible splits do not change extents used in indexing" to "Make non-divisible splits not change extents used in indexing" Nov 16, 2021
@naoyam naoyam requested a review from csarofeen November 16, 2021 18:45
Owner

@csarofeen csarofeen left a comment

Should we clean up the predication logic to always use an explicit list of which iteration domains need to be predicated? I understand why this was only done for non-divisible-splits, but it seems to me we could have a pass that would specify all the domains that need to be predicated, including if root domains need predication, or just the contiguous merge domains. Maybe it would be good to try and unify this logic a bit more.

Looks really good overall, approving!

auto split_factor = expr_eval.evaluate(extent_factor.second);
TORCH_INTERNAL_ASSERT(
input_extent.has_value(),
"Extent not possible to evaluate: ",
Owner

Can you mention in these assert messages that this is from vectorizedSplit validation? A little more context in these messages will be helpful.

//! Split expressions whose divisbility must be validated at run time
std::unordered_set<Split*> splits_to_validate_;

//! Temporary used for analyzing each tensor
Owner

Nit: Temporarily

private:
//! Split expressions whose input domain must be predicated
std::unordered_map<TensorView*, std::vector<Split*>> splits_to_predicate_;
//! Split expressions whose divisbility must be validated at run time
Owner

Nit: divisibility

return partial_split_map_;
}

auto& nonDivisibleSplitInfo() {
Owner

Should non-const references in this header be private? Maybe we should check this as a follow up.

Collaborator Author

As long as GpuLower is passed as a const object after lowering, I think it should be fine, but yes, this class is growing and may need some refactoring.

// supprot shift. If contig_id is not root, nothing is required to
// do for shift as shift-related domains are excluded from
// contig domains.
// supprot shift. If contig_id a merged non-root domain, nothing
Owner

Nit: support shift. If a contig_id is a merged non-root domain, nothing

Collaborator Author

naoyam commented Nov 17, 2021

> Should we clean up the predication logic to always use an explicit list of which iteration domains need to be predicated? I understand why this was only done for non-divisible-splits, but it seems to me we could have a pass that would specify all the domains that need to be predicated, including if root domains need predication, or just the contiguous merge domains. Maybe it would be good to try and unify this logic a bit more.

Agreed. We have contig IDs, non-divisible split IDs, and non-exact parallelized IDs. The last one is added in getInlinePredicate, but there's no reason that all three can't be done uniformly. I'll revisit later.

@csarofeen csarofeen deleted the non-divisible-split branch January 22, 2022 16:24
Development

Successfully merging this pull request may close these issues.

Invalid striding
