Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add turing mma support and test #1643

Merged
merged 73 commits into from
May 24, 2022
Merged

Add turing mma support and test #1643

merged 73 commits into from
May 24, 2022

Conversation

shmsong
Copy link

@shmsong shmsong commented May 2, 2022

Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface.

Copy link
Owner

@csarofeen csarofeen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as well, very minor comments and some questions. Mostly just wanted to refresh my knowledge in the area. Always really exciting to see the MMA scheduling.

torch/csrc/jit/codegen/cuda/runtime/memory.cu Show resolved Hide resolved
#if (__CUDA_ARCH__ < 800)
const unsigned thread_id = threadIdx.x;
// Upper half warp has 8 bytes offset from aligned in .x2 option
// of ldmatrix. Currently no support for .x1 so assume always
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason not to add x1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could add .x1 in a follow up. I currently didn't yet add the smaller mma tiles that .x1 would pair with, they are not immediately useful yet in large CTA tile kernels.

// [M,K,N] -> [M,N,K]
tv2c->reorder({{-2, -1}, {-1, -2}});
tv2c->applyMmaSwizzle(
mma_builder.operand(MmaOptions::Operand::NotOperand).build());
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should NotOperand be renamed to Accumulator?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I was not very sure what'd be a good name, since it's for Accumulator and Bias or any other epilog terms. But Accumulator does sound better. Renamed. Thanks.


auto tv2c = tv2->cacheBefore();

// [K,M,N] -> [N,M,K]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The syntax here is [broadcast, non-accumulation dimension, accumulation dimension] for both operands, correct? This seems to be what's consistently used on the inner tile for the operands.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that'd be the case for Turing and Ampere mma's, and the broadcast dimension could either be present or not, depending on where the shared mem tensor is placed. Both cases are supported in the swizzler.


auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat));

TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does validator not reasonably work here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error turned out to be on the same level as the tolerance set by validator, so I just manually raised the tolerance by an order for now. Was thinking about re-visiting error tolerance in matmul a bit in a follow up.

Base automatically changed from ampere_mma_op to devel May 23, 2022 23:50
@shmsong shmsong merged commit 5e6a8da into devel May 24, 2022
@shmsong shmsong deleted the turing_mma_op branch May 24, 2022 06:38
malfet pushed a commit to pytorch/pytorch that referenced this pull request Jun 8, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen#1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen#1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: csarofeen#1440

Commits that's actually in this PR from the csarofeen branch
```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```
Pull Request resolved: #78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Jun 8, 2022
Summary:
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen#1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen#1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: csarofeen#1440

Commits that's actually in this PR from the csarofeen branch
```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```

Pull Request resolved: #78244

Reviewed By: ejguan

Differential Revision: D36678948

Pulled By: davidberard98

fbshipit-source-id: 0ccde965acbd31da67d99c6adb2eaaa888948105
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen/pytorch#1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen/pytorch#1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: csarofeen/pytorch#1440

Commits that's actually in this PR from the csarofeen branch
```
* dd2325294e236c5082c642819a1103bcfe4561a3 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f446355a2d276bac8272e7aa8b5bb6b1f0 Fix missing cooperative launch (#1726)
* dc670a226cbe52be46cecef47001f38bf9a09433 Async gmem copy support on sm80+ (#1619)
* 5e6a8dab5a71aefe0548bbfa15d1a93c556d23fe Add turing mma support and test (#1643)
* d6d6b7d3f10dd91dafa4cdbd5e460bbb38173af4 Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39150c6d80e0f9f767d56654714a2e8a927 Mma op integration on ampere (#1440)
* fade8da55e60a118c5595378896d34b862b2fcc3 patch python test for bfloat16 (#1724)
* 8fbd0b18743a72ac10478857c3d2351204375685 Fine-grained kernel profiling (#1720)
* 77c1b4fa633f9e631d267923f4537336fa328939 Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b97bebefc94199bb4a53423ede32b55451 More precise concretization analysis (#1719)
* f4d3630ed54d7069dd377a64be1f91013b285b66 Enable complex python tests (#1667)
* 4ceeee509774cc2ce6c834a4dc1e313f71d94503 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70faf218e86d2c78dbd3874b175a3b0a203 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830d5def65dadfe29d4edf52fc703369c84a Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7ae2cfd8fffad1e1d882ae7c50631211dc updating_ci_machine (#1718)
* 56585c58b1ff338704cafb0cd6be2b3d536bed5a Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453d3be0c11a5acb0fff3b3f36e19cfdaf81 Allow using nvFuser on CUDA extension (#1701)
* 18bee67495454b9a79625799776e746bd5e81c4c Validate LOOP concrete IDs have complete IterDomains (#1676)
```
Pull Request resolved: pytorch/pytorch#78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen/pytorch#1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen/pytorch#1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: csarofeen/pytorch#1440

Commits that's actually in this PR from the csarofeen branch
```
* 939e6c9 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* e4a514b Fix missing cooperative launch (#1726)
* 1bb7b65 Async gmem copy support on sm80+ (#1619)
* 69354da Add turing mma support and test (#1643)
* 7ca0fa9 Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 8c5fb93 Mma op integration on ampere (#1440)
* fade8da55e60a118c5595378896d34b862b2fcc3 patch python test for bfloat16 (#1724)
* 1278624 Fine-grained kernel profiling (#1720)
* 34cb422 Adding dry run mode to skip arch dependent checks (#1702)
* 4c3cba4 More precise concretization analysis (#1719)
* 5a9ad9c Enable complex python tests (#1667)
* 8102c05 Minor bugfix in transform_rfactor.cpp (#1715)
* 2c0363c Separate root domain and rfactor domain in TransformPrinter (#1716)
* 1679226 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7ae2cfd8fffad1e1d882ae7c50631211dc updating_ci_machine (#1718)
* acde15c Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453d3be0c11a5acb0fff3b3f36e19cfdaf81 Allow using nvFuser on CUDA extension (#1701)
* e57cc6b Validate LOOP concrete IDs have complete IterDomains (#1676)
```
Pull Request resolved: pytorch/pytorch#78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants