Trivial predicate is causing a 30% slowdown for matmul with grid swizzle #95

Closed
5 tasks done
zasdfgbnm opened this issue Mar 29, 2023 · 2 comments · Fixed by #106

@zasdfgbnm
Collaborator

zasdfgbnm commented Mar 29, 2023

For the example in FusionAmpereSwizzle_CUDA, the generated code contains trivial predicates:

    #pragma unroll
    for(nvfuser_index_t i653 = 0; i653 < 4; ++i653) {
      int i10749;
      i10749 = 32 * i653;
      #pragma unroll
      for(nvfuser_index_t i654 = 0; i654 < 8; ++i654) {
        if (((nvfuser_index_t)blockIdx.x) < ((ceilDiv(T1.size[1], 128)) * 4)) {
          Ampere::M16N8K16TN<16>(
            reinterpret_cast<Array<float,4,4>*>(&T5[(i10749 + (2 * i654))]),
            &(reinterpret_cast<Array<__half,8,8>*>(&T2)[i653]),
            &(reinterpret_cast<Array<__half,4,4>*>(&T3)[i654]));
        }
      }
    }

Here, `((nvfuser_index_t)blockIdx.x) < ((ceilDiv(T1.size[1], 128)) * 4)` is trivially true: the right-hand side of `<` is identical to `gridDim.x`, and `blockIdx.x` is always less than `gridDim.x`. We should simplify this trivial predicate away.
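To make the argument concrete, here is a small host-side C++ sketch (not nvFuser code) that replays the guard for every block index a launch can produce. It assumes, as stated above, that the kernel is launched with `gridDim.x == ceilDiv(T1.size[1], 128) * 4`. `ceilDiv` here is a host mirror of the device helper, assumed to round up; `launchGridDimX` and `predicateIsTrivial` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Host-side mirror of the ceilDiv helper used in the generated kernel
// (assumed semantics: round-up integer division for positive operands).
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Hypothetical helper: the grid size the launch is assumed to use,
// i.e. gridDim.x == ceilDiv(T1.size[1], 128) * 4.
int64_t launchGridDimX(int64_t t1_size1) {
  return ceilDiv(t1_size1, 128) * 4;
}

// Evaluate the guard from the generated kernel for every block index the
// launch can produce (0 <= blockIdx.x < gridDim.x). Returns true if the
// guard never filters out a block, i.e. the predicate is trivial.
bool predicateIsTrivial(int64_t t1_size1) {
  const int64_t grid_dim_x = launchGridDimX(t1_size1);
  for (int64_t block_idx_x = 0; block_idx_x < grid_dim_x; ++block_idx_x) {
    // Same comparison as the `if` in the generated code.
    if (!(block_idx_x < ceilDiv(t1_size1, 128) * 4)) {
      return false;
    }
  }
  return true;
}
```

For any positive `T1.size[1]`, `predicateIsTrivial` returns true: the loop visits exactly the indices in `[0, gridDim.x)`, and the guard compares against that same `gridDim.x` value, so the branch only costs instructions without ever skipping work.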

On an RTX 3090, the kernel takes 20.8374 ms with that trivial predicate vs 16.1956 ms without it (roughly 29% slower with the predicate).

@zasdfgbnm zasdfgbnm self-assigned this Mar 29, 2023
@zasdfgbnm
Collaborator Author

cc @mmigdal-nv @drzejan2

@mmigdal-nv
Collaborator

Thanks a lot for these fixes!

zasdfgbnm added a commit that referenced this issue Apr 1, 2023
This PR is stacked on #105 and fixes #95.

There are two main changes:
- In `lower_scalar_hoist.cpp`, add assumptions about parallel indices
- In `predicate_compute.cpp`, check whether the extent is the same as the parallel dimension
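A toy C++ model of the two changes, assuming expressions can be compared as canonical strings (so "the same expression" means string equality). `ParallelDimMap`, `simplifyLessThan`, and `needsParallelPredicate` are hypothetical names; this is not the actual code in `lower_scalar_hoist.cpp` or `predicate_compute.cpp`, only an illustration of the idea under that assumption.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Toy model, NOT nvFuser's real API: expressions are canonical strings.
// Hypothetical map from a parallel index to the extent its dimension is
// launched with, e.g. "blockIdx.x" -> "ceilDiv(T1.size[1], 128) * 4".
using ParallelDimMap = std::map<std::string, std::string>;

// Change 1 (in the spirit of lower_scalar_hoist.cpp): the simplifier
// assumes 0 <= idx < dim(idx) for every parallel index, so `idx < rhs`
// folds to true whenever rhs is the same value as dim(idx).
// Returns std::nullopt when nothing can be proven and the check must stay.
std::optional<bool> simplifyLessThan(const std::string& idx,
                                     const std::string& rhs,
                                     const ParallelDimMap& dims) {
  auto it = dims.find(idx);
  if (it != dims.end() && it->second == rhs) {
    return true;  // idx < dim(idx) == rhs, by the registered assumption
  }
  return std::nullopt;
}

// Change 2 (in the spirit of predicate_compute.cpp): before emitting
// `idx < extent`, skip it when extent is the parallel dimension itself.
bool needsParallelPredicate(const std::string& idx, const std::string& extent,
                            const ParallelDimMap& dims) {
  auto it = dims.find(idx);
  return it == dims.end() || it->second != extent;
}
```

With `dims = {{"blockIdx.x", "ceilDiv(T1.size[1], 128) * 4"}}`, `simplifyLessThan` folds the guard from this issue to true and `needsParallelPredicate` returns false for it, so either mechanism alone removes the predicate.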

Either of the above changes alone is sufficient to fix the bug, but I made both to be doubly safe.

cc: @mmigdal-nv @drzejan2