
[QST] Using 2xSM load in MMA kernels with input transform stage #3284

@infinitron


What is your question?
I'm trying to build a CollectiveMma with a custom input transformation step, using sm100_9xBF16_umma_builder.inl as the starting point. I'm stuck on this comment:

// Input transform kernel can not use TMA 2SM instructions.

The line below it uses the regular SM90 1SM TMA loads instead of SM100_TMA_2SM_LOAD, even when the kernel schedule tag specifies 2SM. My understanding is that the tiles loaded from gmem are the same in both cases: with the 1SM load instructions, the input matrix is still tiled and partitioned using AtomThrID, which is equivalent to what the 2SM load instructions do. Can anyone correct me if I'm wrong?

For my use case, the transform changes the data types of A and B and also extends matrix B by a factor of two with a custom transform. So I can either use the 1SM load instructions and split the loads appropriately so that the 2SM UMMA instruction can be invoked, or simply use the 2SM loads.
