-
Notifications
You must be signed in to change notification settings - Fork 7
Make non-divisible splits not change extents used in indexing #1270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); | ||
| } | ||
|
|
||
| TEST(NVFuserTest, FusionRfactorPredication2_CUDA) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a test created for #1250 but removed as it was failing due to the extent problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we clean up the predication logic to always use an explicit list of which iteration domains need to be predicated? I understand why this was only done for non-divisible-splits, but it seems to me we could have a pass that would specify all the domains that need to be predicated, including if root domains need predication, or just the contiguous merge domains. Maybe it would be good to try and unify this logic a bit more.
Looks really good overall, approving!
| auto split_factor = expr_eval.evaluate(extent_factor.second); | ||
| TORCH_INTERNAL_ASSERT( | ||
| input_extent.has_value(), | ||
| "Extent not possible to evaluate: ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you mention in these assert messages this is from vectorizedSplit validation. A little more context in these messages will be helpful.
| //! Split expressions whose divisbility must be validated at run time | ||
| std::unordered_set<Split*> splits_to_validate_; | ||
|
|
||
| //! Temporary used for analyzing each tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Temporarily
| private: | ||
| //! Split expressions whose input domain must be predicated | ||
| std::unordered_map<TensorView*, std::vector<Split*>> splits_to_predicate_; | ||
| //! Split expressions whose divisbility must be validated at run time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: divisibility
| return partial_split_map_; | ||
| } | ||
|
|
||
| auto& nonDivisibleSplitInfo() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should non-const references in this header be private? Maybe we should check this as a follow up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As long as GpuLower is passed as a const object after lowering, I think it should be fine, but yes, this class is growing and may need some refactoring.
| // supprot shift. If contig_id is not root, nothing is required to | ||
| // do for shift as shift-related domains are excluded from | ||
| // contig domains. | ||
| // supprot shift. If contig_id a merged non-root domain, nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: support shift. If a contig_id is a merged non-root domain, nothing
Agreed. We have contig IDs, non-divisible split IDs, and non-exact parallelized IDs. The last one is added in |
This is a second attempt to fix an indexing problem that is first reported in PR #1243. See that PR as well for the context of this problem.
Also, the repro of issue #1074 is now working (
FusionNonDivisibleSplit2).This PR addresses an index linearlization problem that was exposed by view and strided gather operations, both of which are based on rFactor. The root cause of the problem is non-divisible splits. When indexing with such splits, the domain that is split gets a larger extent due to the
ceilDivof the outer domain. This is propagated back from leaf to parent domains inIndexCompute::handle(Split*). It turns out that there are consistency problems when to propagate. The most sensible way is likely to completely stop doing so, but it also means intermediate domains have to be predicated if split by a non-divisible factor.Also, predicating such domains doesn't work if vectorization is involved. If an output domain of a split is vectorized, the split must be divisible. Divisibility is validated at the kernel launch time.
The main logic is in
NonDivisibleSplitInfo. It scans a fusion to locate anySplitexpression that must be either predicated or validated. If anIterDomainis run-time validated, it's not predicated as it's unnecessary (NonDivisibleSplitInfo::removeRedundancy()).getReferenceRootPredicatesis extended to generate predicates for those domains that may be split with a non-divisible factor. Currently, it finds root or merged contiguous domains to predicate. I extended that to include non-divisible split domains (getNonDivisibleDomainsToPredicate). Whileindex_compute.cppmay appear to have many changes, most of them don't change the existing logic. For example, the change ingetStartAndStopOffsetsis not to do adjustments for shift and gather operations when predicating non-divisible split domains.All of the C++ tests are passing, but
test_native_layer_norm_halfis failing due to a runtime validation of a non-divisible domain with vectorization. The domain to split is:This
i521domain is split by2and the inner domain of size 2 is vectorized. The test fails withT0being[17, 10, 10]since17 * 10 * 10 / 4isn't divisible by2. Without the runtime check, an alignment error occurs. Since @csarofeen mentioned he's working on changing scheduling algorithms for vectorization, I've done nothing on the failure yet.Closes #1074