Skip to content

[CuteDSL Grouped Gemm] Failure with Multiple grouped_gemm_nt_masked #1856

@wenscarl

Description

@wenscarl

There is an intermittent failure when grouped_gemm_nt_masked is invoked more than once in the same process.
Running `python test.py` (which calls the test twice in a row) can fail, whereas running only a single test always passes.
In the failing case, the FlashInfer result is entirely zeros.
cc. @kaixih @yzh119

# test.py
from typing import Callable

import pytest
import torch
from flashinfer import fp4_quantize
#from flashinfer import (silu_and_mul_nvfp4_batched_quantize, nvfp4_batched_quantize)
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from torch.nn import functional as F

# Skip the entire module when the current CUDA device reports a compute
# capability below (10, 0); the NVFP4 paths exercised below require it.
if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )

# Magnitudes of the FP4 E2M1 codes, indexed by the low 3 bits of a nibble
# (bit 3 is the sign). Used by break_fp4_bytes to decode packed FP4 data.
kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

# Maximum magnitudes of FP8 E4M3 and FP4 E2M1. Their product divided by a
# tensor's amax gives the global quantization scale used throughout the test.
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0

def break_fp4_bytes(a, dtype):
    """Unpack a uint8 tensor of packed FP4 (E2M1) pairs into real values.

    Every byte carries two 4-bit codes; the low nibble is emitted before the
    high nibble. Bit 3 of a code is its sign, bits 0-2 index ``kE2M1ToFloat``.

    Args:
        a: 2-D ``torch.uint8`` tensor of shape (m, n).
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (m, 2 * n) cast to ``dtype``.
    """
    assert a.dtype == torch.uint8
    rows, packed_cols = a.shape

    flat = a.flatten()
    # Interleave so byte j contributes (low_j, high_j) in that order.
    nibbles = torch.stack((flat & 0x0F, (flat & 0xF0) >> 4), dim=1).flatten()

    negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)

    lut = kE2M1ToFloat.to(device=a.device)
    sign = torch.where(negative, -1.0, 1.0)
    decoded = lut[magnitude_idx] * sign

    return decoded.reshape(rows, packed_cols * 2).to(dtype=dtype)

def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the CuTe DSL dtype name for ``input``'s floating-point dtype.

    Raises:
        ValueError: if ``input.dtype`` is not bfloat16, float16 or float32.
    """
    names = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    name = names.get(input.dtype)
    if name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return name

def nvfp4_batched_quantize_proxy(x, sf, mask=None):
    """Quantize every batch slice of ``x`` to NVFP4 via ``fp4_quantize``.

    Stand-in for a batched quantize kernel: loops over the leading dimension
    and calls ``fp4_quantize`` once per slice.

    Args:
        x: Tensor of shape (b, m, n); two FP4 values are packed per output
            byte, so ``n`` is expected to be even.
        sf: Per-slice global scales, indexed as ``sf[i]``.
        mask: Accepted for interface compatibility but ignored -- every row
            of every slice is quantized.

    Returns:
        (out, out_sf): ``out`` is uint8 of shape (b, m, n // 2); ``out_sf``
        is uint8 of shape (b, padded_m * padded_n), where ``m`` is padded up
        to a multiple of 128 and ``n // 16`` up to a multiple of 4.
    """
    b, m, n = x.shape
    padded_m = (m + 127) // 128 * 128
    padded_n = (n // 16 + 3) // 4 * 4

    out = torch.zeros((b, m, n // 2), dtype=torch.uint8, device=x.device)
    # cf. _compute_swizzled_layout_sf_size(m, n // 16, 128)
    out_sf = torch.zeros(
        (b, padded_m * padded_n), dtype=torch.uint8, device=x.device
    )
    for i in range(b):
        # NOTE(review): positional args presumably are sf_vec_size=16,
        # sf_use_ue8m0=False, is_sf_swizzled_layout=True -- confirm against
        # flashinfer's fp4_quantize signature.
        single_out, single_scale = fp4_quantize(x[i], sf[i], 16, False, True)
        out[i] = single_out
        out_sf[i] = single_scale.view(-1)
    return out, out_sf

def scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
    apply_silu: bool = False,
):
    """Quantize a grouped tensor to NVFP4 in the grouped-GEMM layouts.

    Calls ``nvfp4_batched_quantize_proxy`` and re-lays out its outputs into
    the logical layouts passed to flashinfer's grouped GEMM.

    Args:
        input_tensor (Tensor):
            - Shape (l, m, k) if apply_silu=False
            - Shape (l, m, k*2) if apply_silu=True
        input_global_scale (Tensor): Shape (l,)
        mask (Tensor): Mask tensor, forwarded to the quantizer (the proxy
            quantizer used here ignores it).
        apply_silu (bool): Request fused silu_and_mul quantization; not
            available in this file.

    Returns:
        output (Tensor): Quantized tensor, logical shape (m, k//2, l)
        output_scales (Tensor): Blockscale tensor viewed as float8_e4m3fn,
            logical shape (32, 4, rm, 4, rk, l)

    Raises:
        AssertionError: if k is not a multiple of 16.
        NotImplementedError: if ``apply_silu`` is True, because
            silu_and_mul_nvfp4_batched_quantize is not imported here.
    """
    l, m, k_like = input_tensor.shape
    # With apply_silu the last dimension is 2 * k.
    k = k_like // 2 if apply_silu else k_like

    sf_vec_size = 16
    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."

    if apply_silu:
        # Fail fast: otherwise aq / aq_sf would be unbound further down.
        raise NotImplementedError(
            "apply_silu=True requires silu_and_mul_nvfp4_batched_quantize, "
            "which is not available here."
        )

    scale_k = k // sf_vec_size
    padded_k = (scale_k + (4 - 1)) // 4 * 4
    padded_m = (m + (128 - 1)) // 128 * 128

    aq, aq_sf = nvfp4_batched_quantize_proxy(
        input_tensor,
        input_global_scale,
        mask=mask,
    )

    # physical (l, m, k//2) -> logical (m, k//2, l)
    output = aq.permute(1, 2, 0)

    # physical (l, rm, rk, 32, 4, 4) -> logical (32, 4, rm, 4, rk, l)
    output_scales = aq_sf.view(torch.float8_e4m3fn).view(
        l, padded_m // 128, padded_k // 4, 32, 4, 4
    )
    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)

    return output, output_scales

def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the 128x4 tile swizzle of a block-scale tensor.

    The buffer is viewed as (1, m_tiles, k_tiles, 32, 4, 4) = [.., j, i, c]
    and permuted so that, inside a 128-row tile, output row is ``i * 32 + j``
    and output column is ``k_tile * 4 + c``.

    Args:
        a_sf_swizzled: Swizzled scale factors holding exactly
            m_tiles * k_tiles * 512 elements.
        m: Number of real (unpadded) rows.
        k: Number of real elements per row -- NOT the number of scale columns.
        block_size: Elements covered by one scale factor.

    Returns:
        Tensor of shape (m, ceil(k / block_size)) with tile padding removed.
    """
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    # Columns are scale blocks, not elements. Slicing with ``k`` would keep
    # the padded scale columns whenever k / block_size is not a multiple of 4.
    num_scale_cols = (k + block_size - 1) // block_size
    return out[:m, :num_scale_cols]


def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize a packed NVFP4 matrix back to ``dtype``.

    Args:
        tensor_fp4: uint8 tensor of shape (m, k // 2), two FP4 codes per byte.
        tensor_sf: Swizzled block scales, reinterpreted as float8_e4m3fn.
        global_scale: Divisor applied to the decoded block scales.
        dtype: Output dtype.
        device: Unused.
        block_size: Elements covered by one block scale.

    Returns:
        Tensor of shape (m, k) in ``dtype``.
    """
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    cols = 2 * packed_cols
    n_blocks = cols // block_size

    values = break_fp4_bytes(tensor_fp4, dtype).reshape(rows, n_blocks, block_size)

    scales = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size
    )
    scales = scales.to(torch.float32) / global_scale

    # Broadcast one scale over each block of ``block_size`` values.
    scaled = values * scales.unsqueeze(-1)
    return scaled.reshape(rows, cols).to(dtype=dtype)



def compute_routing(router_logits: torch.Tensor, top_k: int):
    """Pick the top-k experts per row and renormalize their probabilities.

    Returns:
        (routing_weights, selected_experts): float32 weights of shape
        (rows, top_k) that sum to 1 per row, and the matching expert indices.
    """
    probs = torch.softmax(router_logits, dim=1, dtype=torch.float)
    top_probs, top_ids = torch.topk(probs, top_k, dim=-1)
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    return top_probs.float(), top_ids


def prepare_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    topk: int,
):
    """Route rows to experts and scatter them into a zero-padded 3-D buffer.

    Rows of ``hidden_states`` must line up one-to-one with
    ``topk_idx.view(-1)`` (the test passes a topk-expanded tensor).

    Returns:
        hidden_states_3d: (num_experts, max_m, hidden) on the default device;
            expert ``i``'s rows occupy ``[:masked_m[i]]`` in their original
            order, and every padding row is zero.
        masked_m: int32 tensor with the number of rows routed to each expert.
        topk_idx: Selected expert indices from ``compute_routing``.
        routing_weights: Normalized routing weights from ``compute_routing``.
    """
    routing_weights, topk_idx = compute_routing(router_logits, topk)
    flat_idx = topk_idx.view(-1)

    # One boolean row selector per expert, reused for counting and scattering.
    expert_masks = [flat_idx == i for i in range(num_experts)]
    counts = [int(mask.sum()) for mask in expert_masks]
    masked_m = torch.tensor(counts, dtype=torch.int32)

    # zeros, not empty: callers take amax over the whole buffer, so padding
    # rows must not contain uninitialized (possibly NaN/huge) memory.
    hidden_states_3d = torch.zeros(
        (num_experts, max(counts, default=0), hidden_states.shape[1]),
        dtype=hidden_states.dtype,
    )
    for i, (mask, count) in enumerate(zip(expert_masks, counts)):
        hidden_states_3d[i, :count, :] = hidden_states[mask]

    return hidden_states_3d, masked_m, topk_idx, routing_weights


# NOTE(review): not referenced by any test in this file. The name suggests
# (m, n, k) problem sizes -- confirm, then either parametrize a test with it
# or remove it.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]

def flashinfer_cutedsl_grouped_gemm_nt_masked(
    hidden_states: torch.Tensor,  # 3d
    input_global_scale: torch.Tensor,  # (l,)
    weights: torch.Tensor,
    w_global_scale: torch.Tensor,  # (l,)
    masked_m: torch.Tensor,
):
    """Run flashinfer's masked grouped NT GEMM on NVFP4-quantized operands.

    Args:
        hidden_states: (l, max_m, k) activations; only the first
            ``masked_m[i]`` rows of group ``i`` are meaningful.
        input_global_scale: (l,) global quantization scales for activations.
        weights: (l, n, k) weights.
        w_global_scale: (l,) global quantization scales for weights.
        masked_m: Number of valid rows per group.

    Returns:
        Output of logical shape (max_m, n, l) in ``weights.dtype`` (a permuted
        view of an (l, max_m, n) buffer, as the kernel requires).
    """
    aq, aq_sf = scaled_fp4_grouped_quant(
        hidden_states,
        input_global_scale,
        masked_m.to(hidden_states.device),
    )
    num_experts, n, k = weights.shape
    bq, bq_sf = scaled_fp4_grouped_quant(
        weights,
        w_global_scale,
        # All n weight rows of every group are valid.
        torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
    )

    out = torch.zeros(
        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
    )
    out = out.permute(1, 2, 0)  # requirement of kernel
    # Undo both global quantization scales on the GEMM output.
    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
        1, 1, num_experts
    )

    grouped_gemm_nt_masked(
        (aq, aq_sf),
        (bq, bq_sf),
        out,
        masked_m.to(aq.device),
        ab_dtype="float4_e2m1fn",
        sf_dtype="float8_e4m3fn",
        # Derived from the actual output buffer instead of hard-coding
        # "bfloat16", so c_dtype always matches ``out``.
        c_dtype=get_cute_dtype(out),
        sf_vec_size=16,
        alpha=alpha,
        # Uses the module-level helper rather than a local duplicate.
        alpha_dtype=get_cute_dtype(alpha),
    )

    return out

@pytest.mark.parametrize(
    "bs, hidden_dim, inter_dim, topk", [(2, 128, 512, 2), (16, 128, 512, 2)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
    bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
    """Compare flashinfer's masked grouped NVFP4 GEMM against a reference.

    The reference quantizes each expert's routed rows and its weights with
    ``fp4_quantize``, dequantizes them, and multiplies in bfloat16. Only the
    first ``masked_m[i]`` rows of each expert are compared, which also fails
    on all-zero or NaN kernel output.
    """
    torch.manual_seed(42)
    B = bs
    D = hidden_dim
    N = inter_dim
    num_experts = 8
    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(B, num_experts, dtype=torch.float32)

    hidden_states_expanded = (
        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    )
    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    # Force padding rows (past masked_m[i]) to zero so they can never affect
    # the per-expert amax / global scales below, however the buffer was
    # allocated.
    max_m = hidden_states_3d.shape[1]
    valid_rows = (
        torch.arange(max_m, device=hidden_states_3d.device)[None, :]
        < masked_m.to(hidden_states_3d.device)[:, None]
    )
    hidden_states_3d = hidden_states_3d.masked_fill(~valid_rows[:, :, None], 0)

    # reference
    out = torch.zeros(
        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
    )
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        if mask.sum():
            lhs = hidden_states_expanded[mask]
            rhs = weights[i]
            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
            b_amax = rhs.abs().max().to(torch.float32).to(weights.device)
            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax

            lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
            rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)

            lhs_in_dtype = dequantize_nvfp4_to_dtype(
                lhsq,
                lhsq_sf,
                a_gs,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                block_size=16,
            )
            rhs_in_dtype = dequantize_nvfp4_to_dtype(
                rhsq,
                rhsq_sf,
                b_gs,
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                block_size=16,
            )
            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()

    a_amax = (
        hidden_states_3d.abs()
        .amax(dim=(1, 2))
        .to(torch.float32)
        .to(hidden_states.device)
    )
    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
    )

    # re-pack out into [num_experts, max_m, n]
    out_ref = torch.zeros(
        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
    )
    expert_slot = [0] * num_experts
    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
        expert_slot[expert_id] += 1
    out_ref = out_ref.to(out_flashinfer.device)

    # Note: just to compare the masked position due to cutedsl may write nan
    # into unmasked position. assert_close rejects NaN by default, and also
    # catches an all-zero result.
    out_flashinfer_3d = out_flashinfer.permute(2, 0, 1)
    for i in range(num_experts):
        rows = int(masked_m[i])
        torch.testing.assert_close(
            out_flashinfer_3d[i, :rows],
            out_ref[i, :rows],
            atol=1e-1,
            rtol=5e-2,
        )



if __name__ == "__main__":
    # Run two configurations back to back in one process; the reported
    # failure only shows up when more than one run happens per process.
    for bs, hidden_dim, inter_dim, topk in ((2, 128, 256, 2), (16, 128, 512, 4)):
        test_grouped_gemm_nt_masked(bs, hidden_dim, inter_dim, topk)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions