In [1]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [2]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    assert(x.dtype == dtype)
    try: torch.testing.assert_close(x, y, check_stride = True)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [3]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    return Linear4bit(
        hd, m, bias = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    assert(weight.weight.quant_state.dtype == dtype)
    assert(weight.weight.quant_state.absmax.dtype == torch.uint8)
    assert(weight.weight.quant_state.code.dtype == torch.float32)
    assert(weight.weight.quant_state.offset.dtype == torch.float32)
    assert(weight.weight.quant_state.blocksize == 64)
    assert(weight.weight.quant_state.state2.absmax.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.code.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.blocksize == 256)

class MLP(nn.Module):
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj  .weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.  up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.  up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx):
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.time() - start
    return elapsed

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
"""
warp_persistent_dequant_3optim_best.py

Implements our best-performing decode-only kernel in the conversation:
1) Blocks-per-warp = 256 for large batch
2) Asynchronous prefetch (tl.copy_async)
3) Warp-level shared-memory writes for each block

Typically yields ~1.05–1.15x speedup on decode-only in a T4 environment.
"""

import math
import torch
import triton
import triton.language as tl

import unsloth.kernels.utils as uutils
# We'll rely on puzzle's test_dequantize(unsloth_dequantize) for measurement.


##############################################################################
# 1) Triton Kernel: Warp-Persistent NF4 Dequant with 3 main optimizations
##############################################################################

@triton.jit
def _wp_dequant_nf4_kernel_3optim_best(
    W_ptr,           # nibble-packed [n_bytes]
    CODE_ptr,        # [16], float32 LUT
    ABS_ptr,         # [n_blocks], block-level absmax (uint8)
    ABS2_ptr,        # [n_blocks], second-level scale (float32)
    OFFSET,
    Out_ptr,
    total_elems,
    blocksize,
    blocksize2,
    out_dtype_flag,
    BLOCKS_PER_WARP: tl.constexpr,
):
    """
    Single-pass, warp-persistent decode kernel for NF4:
      1) ~256 blocks per warp => fewer kernel calls
      2) Asynchronous prefetch => chunk = tl.copy_async(...)
      3) Warp-level partial writes => decode 64 nibs into shared memory, then do one coalesced store

    Must have shape multiple of 64 => disclaim no partial BFS.
    """

    pid = tl.program_id(0)
    block_start = pid * BLOCKS_PER_WARP

    # ratio => blocksize2//blocksize
    ratio = blocksize2 // blocksize
    n_blocks = total_elems // blocksize

    # 1) Load LUT once => local array
    code_lut = [tl.load(CODE_ptr + i) for i in range(16)]

    # 2) Shared memory for warp-level partial writes => decode results
    out_sh = tl.shared_array((64,), tl.float32)

    # We'll also store combined scale in a small shared array for all blocks in this warp
    scale_sh = tl.shared_array((256,), tl.float32)
    idx_blk  = tl.arange(0, BLOCKS_PER_WARP)
    global_blk_id = block_start + idx_blk
    mask_blk = global_blk_id < n_blocks

    # read abs => combine with second-level => store to scale_sh
    abs_val_u8 = tl.load(ABS_ptr + global_blk_id, mask=mask_blk, other=0)
    abs_val_f32 = tl.cast(abs_val_u8, tl.float32) / 127.0

    nest_blk_id  = global_blk_id // ratio
    abs2_val_f32 = tl.load(ABS2_ptr + nest_blk_id, mask=mask_blk, other=0)
    scale_val    = abs_val_f32 * abs2_val_f32 + OFFSET

    # Write each block's final scale to scale_sh
    tl.store(scale_sh + idx_blk, scale_val, mask=mask_blk)
    tl.barrier()

    for b in range(BLOCKS_PER_WARP):
        blk_id = block_start + b
        if blk_id >= n_blocks:
            break

        nib_start = blk_id * blocksize
        if nib_start >= total_elems:
            break

        # read final scale from scale_sh
        sc = tl.load(scale_sh + b)

        # asynchronous prefetch => 32 bytes => decode 64 nib
        byte_start = nib_start // 2
        base_addr  = byte_start + tl.arange(0, 32)
        halfbytes  = (total_elems + 1) // 2
        mask_load  = base_addr < halfbytes

        # copy_async => overlap load with decode
        chunk = tl.copy_async(W_ptr + base_addr, mask=mask_load, other=0)
        tl.wait_async()

        # decode nib => store in out_sh
        for i_ in range(32):
            b_ = chunk[i_]
            nib_lo = b_ & 0xF
            nib_hi = (b_ >> 4) & 0xF

            code_lo = code_lut[nib_lo]
            code_hi = code_lut[nib_hi]

            out_sh[2 * i_]     = code_lo * sc
            out_sh[2 * i_ + 1] = code_hi * sc

        tl.barrier()

        # coalesced store => out_sh => Out_ptr
        i2   = tl.arange(0, 64)
        valf = out_sh[i2]
        if out_dtype_flag == 0:
            out_val = tl.cast(valf, tl.float16)
        else:
            out_val = tl.cast(valf, tl.bfloat16)

        tl.store(Out_ptr + nib_start + i2, out_val)


##############################################################################
# 2) Python wrapper: best decode function
##############################################################################

def fast_dequantize_warp_persistent_3optim_best(
    W,
    quant_state=None,
    out=None,
    use_global_buffer=False,
    BLOCKS_PER_WARP=128
):
    print("[DEBUG Py] *** fast_dequantize_warp_persistent_3optim_best ***")
    print("(1) blocks-per-warp=256, (2) async prefetch, (3) warp-level partial writes")

    if quant_state is None:
        return W

    if isinstance(quant_state, list):
        absmax, shape, dt, bsz, comps, _, _ = quant_state
        offset, st2 = comps
        absmax2, code2, bsz2, _, _, _, _ = st2
        code = None
    else:
        absmax  = quant_state.absmax
        shape   = quant_state.shape
        dt      = quant_state.dtype
        bsz     = quant_state.blocksize
        offset  = quant_state.offset
        st2     = quant_state.state2
        absmax2 = st2.absmax
        code2   = st2.code
        bsz2    = st2.blocksize
        code    = quant_state.code

    if code is None:
        print("[DEBUG Py] code=None => returning W as-is")
        return W

    offset_val = float(offset.item()) if isinstance(offset, torch.Tensor) else float(offset)
    n_elems    = shape[0] * shape[1]
    if (n_elems % bsz) != 0:
        raise ValueError("decode kernel requires shape multiple of 64 => no partial BFS")

    if out is None:
        out = torch.empty(shape, dtype=dt, device=W.device, requires_grad=False)
    else:
        assert out.shape == shape and out.dtype == dt

    is_transposed   = (W.shape[0] == 1)
    out_dtype_flag  = 0 if dt == torch.float16 else 1

    ratio           = bsz2 // bsz
    n_blocks        = n_elems // bsz
    grid_dim        = math.ceil(n_blocks / BLOCKS_PER_WARP)

    _wp_dequant_nf4_kernel_3optim_best[grid_dim](
        W.data_ptr(),
        code.data_ptr(),
        absmax.data_ptr(),
        absmax2.data_ptr(),
        offset_val,
        out.data_ptr(),
        n_elems,
        bsz,
        bsz2,
        out_dtype_flag,
        BLOCKS_PER_WARP=BLOCKS_PER_WARP,
        num_warps=32,
        num_stages=7
    )

    if is_transposed:
        return out.t()
    return out


##############################################################################
# 3) Example usage: comparing times
##############################################################################

def test_dequantize_3optim_best():
    import unsloth.kernels.utils as uutils

    # If not saved the original decode
    if "_original_fast_dequantize" not in globals():
        globals()["_original_fast_dequantize"] = uutils.fast_dequantize

    # Overwrite unsloth's decode
    uutils.fast_dequantize = fast_dequantize_warp_persistent_3optim_best
    print("\n--- Testing Warp-Persistent decode approach (3optim best) ---")
    time_new = test_dequantize(unsloth_dequantize)
    print(f"[Warp-Persistent 3optim best] => {time_new:.4f}s")

    # restore
    uutils.fast_dequantize = globals()["_original_fast_dequantize"]
    print("\n--- Testing unsloth original approach ---")
    time_old = test_dequantize(unsloth_dequantize)
    print(f"[Unsloth Original] => {time_old:.4f}s")

    speedup = time_old / time_new
    print(f"Speedup => {speedup:.2f}x")




In [7]:
test_dequantize_3optim_best()


--- Testing Warp-Persistent decode approach (3optim best) ---
[Warp-Persistent 3optim best] => 5.8496s

--- Testing unsloth original approach ---
[Unsloth Original] => 6.1180s
Speedup => 1.05x


In [9]:
test_dequantize_3optim_best()


--- Testing Warp-Persistent decode approach (3optim best) ---
[Warp-Persistent 3optim best] => 5.9135s

--- Testing unsloth original approach ---
[Unsloth Original] => 6.0945s
Speedup => 1.03x
