Mma op integration on ampere #1440
Conversation
2b4f3b0 to edd43d9 (compare)
Looks really good :-)
std::stringstream ss;

ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
   << vec_size << ">*>(&" << gen(val) << ")";
The reinterpret_cast seemed quite unreliable; are you sure you don't want to make sure we thunk to PTX if we can?
Yes, eventually I think we want to explicitly specify which 4 registers to pass to ldmatrix, once we have the swizzle labeling part merged.
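For context, a minimal self-contained sketch of how such a vectorized pointer cast string could be assembled; `genVectorPointer`, its parameters, and the example operand names are hypothetical illustrations of the pattern quoted above, not nvfuser's actual codegen entry point:

```
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical helper (illustration only): builds the cast expression that
// reinterprets a scalar element as a pointer to a vectorized Array, using the
// same string-building pattern as the snippet quoted above.
std::string genVectorPointer(
    const std::string& val,   // already-generated operand, e.g. "T1[i]"
    const std::string& dtype, // element type name, e.g. "__half"
    int vec_size) {
  std::stringstream ss;
  ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
     << vec_size << ">*>(&" << val << ")";
  return ss.str();
}

int main() {
  // Prints: reinterpret_cast<Array<__half,8,8>*>(&T1[i])
  std::cout << genVectorPointer("T1[i]", "__half", 8) << std::endl;
  return 0;
}
```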
//! 2. direct output of a broadcast op following a ldmatrix op
//! Returns true if the tv is an immediate output of ldmatrix op
//!
//! TODO: this check is
Nit: what's the todo for?
It's for eventually removing the pattern matching for ldmatrix. I've completed the TODO comment. Thanks.
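For illustration, a minimal sketch of the check the quoted comment describes; the stand-in IR types below are hypothetical simplifications for this sketch, not nvfuser's actual classes:

```
// Hypothetical, simplified IR node types for illustration only.
struct Expr;
struct TensorView {
  Expr* definition = nullptr; // op producing this tensor, if any
};
struct Expr {
  enum class Kind { LdMatrix, Broadcast, Other } kind = Kind::Other;
  TensorView* input = nullptr; // a single input is enough for this sketch
};

// A tv is treated as an ldmatrix output if it is either
//  1. the immediate output of an ldmatrix op, or
//  2. the output of a broadcast op whose input is an immediate ldmatrix output.
bool isLdMatrixOutput(const TensorView* tv) {
  const Expr* def = tv->definition;
  if (def == nullptr) {
    return false;
  }
  if (def->kind == Expr::Kind::LdMatrix) {
    return true;
  }
  if (def->kind == Expr::Kind::Broadcast && def->input != nullptr &&
      def->input->definition != nullptr) {
    // Exactly one broadcast on top of an ldmatrix op is accepted.
    return def->input->definition->kind == Expr::Kind::LdMatrix;
  }
  return false;
}
```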
tv->split(-1, 4);
tv->split(-2, 2);

// 0 1 2 3 4
Nit: 4?
Cleaned up. Thanks!
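As a hedged illustration of what the two quoted splits do (the starting extent of 32 is chosen only for this example):

```
// Starting layout of the innermost domain (example extent):
//   [..., I(32)]
// tv->split(-1, 4) splits the last domain by 4:
//   [..., I(8), I(4)]
// tv->split(-2, 2) then splits the new second-to-last domain by 2:
//   [..., I(4), I(2), I(4)]
```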
if (options.operand == MmaOptions::Operand::A) {
  TORCH_INTERNAL_ASSERT(tv->nDims() >= 2);
  // validation:
  TORCH_INTERNAL_ASSERT(canValidateIsInnerDim(
Nit: Asserts should have something printed unless it should be impossible to reach.
Added error messages in these checks. Thanks.
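For example, TORCH_INTERNAL_ASSERT accepts trailing message arguments, so a message can be attached directly; the wording below is a hypothetical sketch, not necessarily what the PR ended up using:

```
// Hedged example only; the actual message text in the PR may differ.
TORCH_INTERNAL_ASSERT(
    tv->nDims() >= 2,
    "MMA operand A scheduling expects a tensor with at least 2 dimensions, but got ",
    tv->nDims());
```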
Looks like some non-trivial changes have been made since I reviewed the last time. Let me take a look at them.
// T2 = 0; // init for exp1
// if(pred)
//   T2 = T1 ... // exp1
// If we remove pred around expr1, as the way the pred removal
This only happens when exp1 is a reduction-like op.
// T1[i] = T0[i] + ...
// Note that we'd be able to reuse buffer of T0 for T1 but
// if we initialize T1 we cannot do that and thus the
// kernel would not fit in smaller devices.
I think it's fine to disable in this case for now, but it looks like for this kind of pattern we would want to just initialize T0 and eliminate the predicates for T1 and T2, which would not interfere with the buffer reuse between T0 and T1.
Yes. The holistic predicate/initialization analysis has been on the TODO list, and quite a few TODOs in this PR actually boil down to that. I will try to get to that in follow-ups.
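A hedged sketch of the pattern being discussed (T0/T1/T2 follow the names used above; indexing is illustrative only): initializing only T0 and dropping the predicates on T1 and T2 keeps the values well defined without forcing T1 to get a separately initialized buffer.

```
// Illustrative pseudocode, not generated kernel code.
// T0[i] = 0;                 // initialize only the producer
// if (pred)
//   T0[i] = in[i];           // guarded load
// T1[i] = T0[i] + ...;       // predicate eliminated: reads initialized T0
// T2[i] = T1[i] ...;         // predicate eliminated as well
//
// Since T1 itself is never separately initialized, the buffer reuse between
// T0 and T1 mentioned above is not blocked.
```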
@@ -213,7 +213,13 @@ void SyncMap::build(Fusion* fusion) {
      // If consumer is parallelized with this type but producer is
      // predicated redundant on this type. This parallel dimension
      // is a RAW dimension. See test: FusionSeriaSmemWriteParallelRead1/2
      if (c_id != nullptr && producer_redundant_types.get(parallel_type)) {
      //
Is this specific to this PR?
Can you also please add a test for this specific pattern?
Does this mean that whenever we have a redundant op, we would insert a sync immediately after it? Then what would happen if we have a chain of redundant exprs? In my naive quick thinking, it seems we should just add a sync when an expr chain becomes non-redundant.
Why don't we extract this change out of this PR?
The sync insertion wouldn't be right after the redundant op, which I believe currently only means a redundant write. The usual RAW sync insertion rule still applies here, i.e. a sync right before the values are used.
As noted in the comment and TODO, this sync wouldn't be needed if all use chains of this redundant write end with another redundant shared/global write of the same redundant type. That case would need some additional traversal info somewhere.
As this issue currently lives in devel, I just wanted to do a quick patch and then follow up.
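A hedged illustration of the redundant-write / parallel-read pattern this check covers (compare the FusionSeriaSmemWriteParallelRead tests referenced in the quoted comment); the kernel below uses hypothetical names and is illustrative, not nvfuser-generated code:

```
// Illustrative CUDA kernel, hypothetical names only.
__global__ void redundant_write_then_parallel_read(const float* T0, float* T3) {
  __shared__ float smem_T1[256];
  // Producer predicated "redundant" on TIDx: only thread 0 writes smem.
  if (threadIdx.x == 0) {
    for (int i = 0; i < 256; ++i) {
      smem_T1[i] = T0[i];
    }
  }
  // RAW sync right before the values are used, not right after the
  // redundant write expression itself.
  __syncthreads();
  // Consumer parallelized on TIDx reads values written by thread 0 only.
  T3[threadIdx.x] = smem_T1[threadIdx.x] + 1.0f;
}
```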
Removed from this PR and moved to #1684
Checked perf delta on V100. It looked like the change in this PR didn't affect norm perf; the delta seemed to be within noise range. Below are the top ones with slowdown > 2% (sorted in decreasing order by %).
Agree, seems within noise. Go ahead and merge.
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen#1619
2. Emulate Ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen#1643
3. Extending the infrastructure to support mma operators on Turing and Ampere arch: csarofeen#1440

Commits that are actually in this PR from the csarofeen branch:

```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```

Pull Request resolved: #78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
This PR is a continuation of #1439; the focus is on extending the infrastructure to support mma operators on Turing and Ampere architectures, including: