In [1]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [1]:
##############################################################################
# 1) Existing puzzle snippet: MLP, test harness, etc.
##############################################################################
import torch
import torch.nn as nn
import unsloth.kernels.utils as uutils
from transformers import set_seed
import time
import inspect
import os

major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)

from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno
WARN = lambda x: print(f"\033[31m{x}\033[0m")

def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    assert(x.dtype == dtype)
    try:
        torch.testing.assert_close(x, y, check_stride=True)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)

def bnb_Linear4bit(hd, m, dtype=torch.float16):
    return Linear4bit(
        hd, m, bias=None,
        compute_dtype=dtype,
        compress_statistics=True,
        quant_type="nf4",
    )

def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    qs = weight.weight.quant_state
    assert(qs.dtype == dtype)
    assert(qs.absmax.dtype == torch.uint8)
    assert(qs.code.dtype == torch.float32)
    assert(qs.offset.dtype == torch.float32)
    assert(qs.blocksize == 64)
    st2 = qs.state2
    assert(st2.absmax.dtype == torch.float32)
    assert(st2.code.dtype == torch.float32)
    assert(st2.blocksize == 256)

class MLP(nn.Module):
    def __init__(self, hd=4096, m=14336, dtype=torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype=dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype=dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype=dtype).to("cuda")
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj  .weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx):
    elapsed = 0
    options = [
        (2, 3333, 2048, 8192, 3407, torch.float16),
        (5,  777, 1024, 4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd=hd, m=m, dtype=dt)
        X = torch.randn((bsz, qlen, hd), device="cuda", dtype=dt)
        torch.cuda.synchronize()

        # Warmup
        for _ in range(2):
            assert_same(mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            assert_correct_bnb(mlp.up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmark
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000):
            mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.time() - start
    return elapsed


##############################################################################
# 2) Final "warp-streaming parallel dequant" approach:
#    - warp-level pipe: uses tl.async_commit_group()
#    - uses tl.tensor_dot() for nib-lut decode
#    - tries to push past memory scheduling issues
##############################################################################
import triton
import triton.language as tl

def final_warp_config(shape):
    """
    We'll pick some synergy for warp-based streaming.
    Possibly smaller warps or moderate blocks.
    This can be tuned.
    """
    bsz, qlen = shape
    # We'll fix e.g. warps=8, stages=4 as a guess for streaming pipeline
    num_warps = 128
    num_stages = 2
    return num_warps, num_stages

@triton.jit
def _nf4_dequant_warp_streaming_kernel(
    W_PTR, CODE_PTR, ABS_PTR, ABS2_PTR, OFFSET,
    Out_PTR,
    total_elems,
    blocksize,
    blocksize2,
    out_dtype_flag,
    BLOCKS_PER_WARP: tl.constexpr,
):
    """
    Single kernel with "warp-streaming" approach:
    - uses tl.async_commit_group() to pipeline
    - uses tl.tensor_dot() for nib decode
    - tries to override Triton's default memory scheduling
    """
    pid = tl.program_id(0)
    warp_start = pid * BLOCKS_PER_WARP

    n_blocks = total_elems // blocksize

    # load code LUT => shape [16]
    code_lut = tl.load(CODE_PTR + tl.arange(0, 16))

    for i in range(BLOCKS_PER_WARP):
        blk_id = warp_start + i
        if blk_id >= n_blocks:
            break

        nib_start = blk_id * blocksize
        if nib_start >= total_elems:
            break

        # compute final scale => (abs/127)*abs2 + offset
        abs_u8 = tl.load(ABS_PTR + blk_id)
        abs_f32 = tl.cast(abs_u8, tl.float32) / 127.0
        abs2_f32 = tl.load(ABS2_PTR + blk_id)
        final_scale = abs_f32 * abs2_f32 + OFFSET

        # read nib-coded => 32 bytes
        byte_start = nib_start // 2
        base_addr = byte_start + tl.arange(0, 32)
        mask = base_addr < ((total_elems + 1)//2)
        chunk = tl.load(W_PTR + base_addr, mask=mask, other=0)

        # pipeline => commit
        tl.async_commit_group()

        # decode nib => warp-level => use tl.tensor_dot
        nib_lo = chunk & 0xF
        nib_hi = chunk >> 4

        # shape => [32], we do one-hot => then tl.tensor_dot => partial
        # or direct gather => if newer triton allows gather. We'll do partial approach:
        # We'll do something akin to a vector approach:
        # For demonstration, do partial one-hot =>
        # If older Triton balks, you'll need to adapt.
        nidx = tl.arange(0, 16)
        nib_lo_broad = nib_lo[:, None]
        onehot_lo = tl.where(nib_lo_broad == nidx[None, :], 1.0, 0.0)
        decoded_lo = tl.tensor_dot(onehot_lo, code_lut)* final_scale

        nib_hi_broad = nib_hi[:, None]
        onehot_hi = tl.where(nib_hi_broad == nidx[None, :], 1.0, 0.0)
        decoded_hi = tl.tensor_dot(onehot_hi, code_lut)* final_scale

        # store => 64
        for j in range(32):
            outLo = nib_start + 2*j
            outHi = outLo + 1
            if outLo < total_elems:
                val_lo = decoded_lo[j]
                if out_dtype_flag==0:
                    val_lo = tl.cast(val_lo, tl.float16)
                else:
                    val_lo = tl.cast(val_lo, tl.bfloat16)
                tl.store(Out_PTR + outLo, val_lo)
            if outHi < total_elems:
                val_hi = decoded_hi[j]
                if out_dtype_flag==0:
                    val_hi = tl.cast(val_hi, tl.float16)
                else:
                    val_hi = tl.cast(val_hi, tl.bfloat16)
                tl.store(Out_PTR + outHi, val_hi)

def final_warp_streaming_dequant(weight, quant_state=None, out=None):
    """
    Single kernel approach that tries to do warp-level streaming
    with tl.async_commit_group, tl.tensor_dot,
    in hopes of surpassing 1.15x speed
    """
    if quant_state is None:
        return weight

    absmax  = quant_state.absmax
    shape   = quant_state.shape
    dt      = quant_state.dtype
    bsz     = quant_state.blocksize
    offset  = quant_state.offset
    st2     = quant_state.state2
    absmax2 = st2.absmax
    code2   = st2.code
    bsz2    = st2.blocksize

    offset_val = float(offset.item()) if isinstance(offset, torch.Tensor) else float(offset)
    n_elems = shape[0]*shape[1]
    if (n_elems%bsz)!=0:
        raise ValueError("Multiple of 64 required")

    if out is None:
        out = torch.empty(shape, dtype=dt, device=weight.device, requires_grad=False)
    else:
        assert out.shape == shape and out.dtype == dt

    is_transposed = (weight.shape[0]==1)
    out_dtype_flag = 0 if dt==torch.float16 else 1

    n_blocks = n_elems//bsz
    BLOCKS_PER_WARP=256
    # pick warps/stages from final_warp_config
    bsz_, qlen_ = shape
    num_warps, num_stages = final_warp_config((bsz_, qlen_))

    import math
    grid_dim = math.ceil(n_blocks / BLOCKS_PER_WARP)

    _nf4_dequant_warp_streaming_kernel[grid_dim](
        weight,
        code2,
        absmax,
        absmax2,
        offset_val,
        out,
        n_elems,
        bsz,
        bsz2,
        out_dtype_flag,
        BLOCKS_PER_WARP=BLOCKS_PER_WARP,
        num_warps=num_warps,
        num_stages=num_stages
    )

    if is_transposed:
        return out.t()
    return out


def test_dequantize_final_warp_streaming():
    """
    Overwrite puzzle decode => measure speed
    Expect >=1.15x if warp-streaming approach helps
    """
    import unsloth.kernels.utils as uutils
    if "_original_fast_dequantize" not in globals():
        globals()["_original_fast_dequantize"] = uutils.fast_dequantize

    # override decode
    uutils.fast_dequantize = final_warp_streaming_dequant

    print("\n--- Testing 'Warp-Streaming Parallel Dequant' Single-Kernel Approach ---")
    time_new = test_dequantize(unsloth_dequantize)
    print(f"[Warp-Streaming Approach] => {time_new:.4f}s")

    # restore
    uutils.fast_dequantize = globals()["_original_fast_dequantize"]
    print("\n--- Testing unsloth original approach ---")
    time_old = test_dequantize(unsloth_dequantize)
    print(f"[Unsloth Original] => {time_old:.4f}s")

    speedup = time_old / time_new
    print(f"Speedup => {speedup:.2f}x")


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
test_dequantize_final_warp_streaming()


--- Testing 'Warp-Streaming Parallel Dequant' Single-Kernel Approach ---
[Warp-Streaming Approach] => 5.6515s

--- Testing unsloth original approach ---
[Unsloth Original] => 7.1430s
Speedup => 1.26x


In [3]:
test_dequantize_final_warp_streaming()


--- Testing 'Warp-Streaming Parallel Dequant' Single-Kernel Approach ---
[Warp-Streaming Approach] => 5.5614s

--- Testing unsloth original approach ---
[Unsloth Original] => 6.1355s
Speedup => 1.10x
