<a href="https://colab.research.google.com/github/csalnav2/QdotCS/blob/master/nf4Triton1.02x.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# =============================
# 1) Full environment reset
# =============================
!pip uninstall -y torch torchvision torchaudio triton xformers bitsandbytes unsloth unsloth_zoo fastai cut_cross_entropy

# =============================
# 2) Install nightly PyTorch (>=2.6.0 dev) + matching Triton
# =============================
!pip install --no-cache-dir --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
!pip install --no-cache-dir triton>=3.1.0

# =============================
# 3) Install extras
# =============================
!pip install --no-deps xformers==0.0.29 bitsandbytes accelerate peft trl
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth_zoo
!pip install --no-deps unsloth
!pip install tyro

import torch
print("Torch:", torch.__version__)

import triton
print("Triton:", triton.__version__)

# If Colab complains "You must restart runtime," do it and re-run.


Found existing installation: torch 2.5.1+cu124
Uninstalling torch-2.5.1+cu124:
  Successfully uninstalled torch-2.5.1+cu124
Found existing installation: torchvision 0.20.1+cu124
Uninstalling torchvision-0.20.1+cu124:
  Successfully uninstalled torchvision-0.20.1+cu124
Found existing installation: torchaudio 2.5.1+cu124
Uninstalling torchaudio-2.5.1+cu124:
  Successfully uninstalled torchaudio-2.5.1+cu124
Found existing installation: triton 3.1.0
Uninstalling triton-3.1.0:
  Successfully uninstalled triton-3.1.0
[0mFound existing installation: fastai 2.7.18
Uninstalling fastai-2.7.18:
  Successfully uninstalled fastai-2.7.18
[0mLooking in indexes: https://download.pytorch.org/whl/nightly/cu121
Collecting torch
  Downloading https://download.pytorch.org/whl/nightly/cu121/torch-2.6.0.dev20241112%2Bcu121-cp311-cp311-linux_x86_64.whl (768.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m768.0/768.0 MB[0m [31m332.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch

In [2]:
import torch
import torch.nn as nn
import time
import inspect
from transformers import set_seed
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN

# For unsloth-based NF4 decode
from unsloth.kernels.utils import fast_dequantize

# For PEFT-based NF4 decode
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", device)

def assert_same(x, y, dt):
    if x.dtype != dt:
        raise RuntimeError(f"dtype mismatch: got {x.dtype}, expected {dt}")
    torch.testing.assert_close(x, y, check_stride=True)

def bnb_Linear4bit(hd, m, dt=torch.float16):
    return Linear4bit(hd, m, bias=None, compute_dtype=dt, compress_statistics=True, quant_type="nf4")

def assert_bnb_state(w, dt):
    assert w.weight.dtype == torch.uint8
    s = w.weight.quant_state
    assert s.dtype == dt
    assert s.absmax.dtype == torch.uint8
    assert s.code.dtype == torch.float32
    assert s.offset.dtype == torch.float32
    assert s.blocksize == 64
    assert s.state2.absmax.dtype == torch.float32
    assert s.state2.code.dtype == torch.float32
    assert s.state2.blocksize == 256

class MLP(nn.Module):
    def __init__(self, hd=4096, m=14336, dt=torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dt).to(device)
        self.up_proj   = bnb_Linear4bit(hd, m, dt).to(device)
        self.down_proj = bnb_Linear4bit(m, hd, dt).to(device)
        # Force correct quant_state dtype
        self.gate_proj.weight.quant_state.dtype = dt
        self.up_proj.weight.quant_state.dtype   = dt
        self.down_proj.weight.quant_state.dtype = dt
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    return mlp.act_fn(gate) * up @ fx(mlp.down_proj).t()

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.up_proj).t();   torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a,b,c

def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)

def test_dequantize(dequant_fx):
    e = 0
    # Some test configs
    configs = [
      (2, 3333, 2048, 8192, 3407, torch.float16),
      (5,  777, 1024, 4096, 3409, torch.bfloat16),
      (3, 2048, 4096, 14336,3408, torch.bfloat16),
    ]
    for (bsz, ql, hd, m, seed, dt) in configs:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd, m, dt)
        X   = torch.randn((bsz, ql, hd), device=device, dtype=dt)
        torch.cuda.synchronize()
        # Warmup checks
        for _ in range(2):
            out_manual = mlp_forward(X, mlp, dequant_fx)
            out_model  = mlp(X)
            assert_same(out_manual, out_model, dt)
            assert_bnb_state(mlp.up_proj, dt)
            assert_bnb_state(mlp.gate_proj, dt)
            assert_bnb_state(mlp.down_proj, dt)
            a,b,c = mlp_dequantize(X, mlp, dequant_fx)
            A,B,C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a,A, dt)
            assert_same(b,B, dt)
            assert_same(c,C, dt)
        # Benchmark
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(1000):
            mlp_dequantize(X, mlp, dequant_fx)
        e += (time.time() - t0)
    return e

# Test unsloth + peft
if device == "cuda":
    e_unsloth = test_dequantize(unsloth_dequantize)
    print(f"[INFO] unsloth_dequantize total time: {e_unsloth:.4f}s")
    e_peft = test_dequantize(peft_dequantize)
    print(f"[INFO] peft_dequantize total time: {e_peft:.4f}s")
else:
    print("[INFO] CPU environment, skipping NF4 bitsandbytes tests.")


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0.dev20241112+cu121)
    Python  3.11.11 (you have 3.11.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!
Running on device: cuda
[INFO] unsloth_dequantize total time: 5.2716s
[INFO] peft_dequantize total time: 5.4960s


In [1]:
!pip uninstall -y torch triton bitsandbytes unsloth
!pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
!pip install --no-cache-dir triton>=3.1.0
!pip install bitsandbytes
!pip install --no-deps unsloth
!pip install --no-deps peft


Found existing installation: torch 2.6.0.dev20241112+cu121
Uninstalling torch-2.6.0.dev20241112+cu121:
  Successfully uninstalled torch-2.6.0.dev20241112+cu121
Found existing installation: triton 3.2.0
Uninstalling triton-3.2.0:
  Successfully uninstalled triton-3.2.0
[0mLooking in indexes: https://download.pytorch.org/whl/nightly/cu121
Collecting torch
  Using cached https://download.pytorch.org/whl/nightly/cu121/torch-2.6.0.dev20241112%2Bcu121-cp311-cp311-linux_x86_64.whl (768.0 MB)
Installing collected packages: torch
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
unsloth-zoo 2025.2.7 requires cut_cross_entropy, which is not installed.
unsloth-zoo 2025.2.7 requires triton; platform_system == "Linux", which is not installed.
xformers 0.0.29 requires torch==2.5.1, but you have torch 2.6.0.dev20241112+cu121 which is incompatible.
unsloth-zoo 2025.2.7 req

In [3]:
# ================================ Cell start ================================
import torch
import triton
import triton.language as tl
from triton import jit

# 1) Minimal kernel with *no* parameters or shape arguments.
#    We literally do arange(0, 512) => a compile-time constant: 512
@jit
def kernel_512(X_ptr, Y_ptr):
    # Hard-coded literal => 512 => must be recognized as a compile-time constant
    idx = tl.arange(0, 512)
    x_val = tl.load(X_ptr + idx)
    out   = x_val + 1.0
    tl.store(Y_ptr + idx, out)

# 2) Attempt to run
try:
    x = torch.randn(512, device="cuda")
    y = torch.empty_like(x)
    # Launch with grid=(1,) => single block
    kernel_512[(1,)](x, y)
    print("[INFO] Triton kernel_512 succeeded!")
    print("[INFO] y[:10] =", y[:10])
except Exception as e:
    print("[ERROR] kernel_512 failed =>", e)
# ================================ Cell end =================================


[INFO] Triton kernel_512 succeeded!
[INFO] y[:10] = tensor([-0.3428,  1.0706,  0.2601,  0.5595,  1.5209, -0.2726,  1.5390,  1.3119,
         1.6588,  1.2323], device='cuda:0')


In [2]:
!pip install transformers accelerate bitsandbytes peft



In [4]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model, TaskType

# Ensure we have a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("[INFO] device:", device)


[INFO] device: cuda


In [None]:
import torch
import torch.nn as nn
import time

from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN

################################################################################
# 1) MLP using bitsandbytes 4-bit layers
################################################################################
def bnb_Linear4bit(in_features, out_features, dtype=torch.float16):
    """
    We'll forcibly interpret arguments as (in_features, out_features)
    though bitsandbytes might flatten them.
    """
    return Linear4bit(
        in_features, out_features,
        bias=None,
        compute_dtype=dtype,
        compress_statistics=True,
        quant_type="nf4",
    )

class MLP(nn.Module):
    def __init__(self, hd=2048, m=8192, dtype=torch.float16):
        super().__init__()
        # Gate & Up => (hd->m), down => (m->hd)
        self.gate_proj = bnb_Linear4bit(hd, m, dtype=dtype).cuda()
        self.up_proj   = bnb_Linear4bit(hd, m, dtype=dtype).cuda()
        self.down_proj = bnb_Linear4bit(m, hd, dtype=dtype).cuda()

        for layer in [self.gate_proj, self.up_proj, self.down_proj]:
            layer.weight.quant_state.dtype = dtype

        self.act_fn = ACT2FN["silu"]

        print("\n=== MLP LAYER SHAPES (bitsandbytes param) ===")
        print("gate_proj =>", self.gate_proj.weight.shape)
        print("up_proj   =>", self.up_proj.weight.shape)
        print("down_proj =>", self.down_proj.weight.shape)

    def forward(self, x):
        return self.down_proj(
            self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        )

################################################################################
# 2) A decode function that re-shapes (8388608,1)->(8192,1024)->(8192,2048)
################################################################################
def nf4_decode(layer):
    """
    If layer.weight.shape => (8388608,1), we interpret out_features=8192,
    half_in=1024 => final => (8192,2048) after decoding to float16.

    If shape is different, handle accordingly.
    """
    w_packed = layer.weight.data  # nibble-packed
    shape_packed = w_packed.shape  # e.g. (8388608,1)

    if shape_packed[0] == 8388608 and shape_packed[1] == 1:
        # we interpret out_features=8192, half_in=1024 => so final => (8192, 2048)
        out_features, half_in = 8192, 1024
        # create a random float16 matrix => (8192, 2048) as a placeholder decode
        decoded = torch.randn(out_features, half_in*2, device=w_packed.device, dtype=torch.float16)
        return decoded
    else:
        # If bitsandbytes stored it differently, adapt logic
        # For example, if shape=(m, hd//2), we do final => (m, hd)
        m, half_hd = shape_packed
        # Maybe it's already (8192, 1024)? Then we do => (8192, 2048)
        decoded = torch.randn(m, half_hd*2, device=w_packed.device, dtype=torch.float16)
        return decoded

################################################################################
# 3) A small forward that decodes "gate_proj" => matmul => see if it works
################################################################################
def debug_forward(X, mlp):
    """
    We'll decode the gate_proj weight => shape (m, hd),
    then matmul with X => shape(bsz, qlen, hd).
    """
    w_dec = nf4_decode(mlp.gate_proj)  # (m, hd) => e.g. (8192, 2048)
    print("[debug_forward] w_dec.shape =>", w_dec.shape)

    bsz, qlen, hd = X.shape
    # (bsz*qlen, hd) x (hd, m)
    X_2d= X.reshape(bsz*qlen, hd)
    out_2d= X_2d @ w_dec.t()  # => (bsz*qlen, m)
    return out_2d.reshape(bsz, qlen, -1)

################################################################################
# 4) Running the test
################################################################################
def run_test():
    # typical => hd=2048 => m=8192 => in_features=2048 => out_features=8192
    mlp = MLP(2048, 8192, dtype=torch.float16)
    # X => (2, 3333, 2048)
    X= torch.randn((2, 3333, 2048), device="cuda", dtype=torch.float16)
    print("X.shape =>", X.shape)

    out= debug_forward(X, mlp)
    print("[INFO] final out.shape =>", out.shape)

if __name__=="__main__":
    run_test()



=== MLP LAYER SHAPES (bitsandbytes param) ===
gate_proj => torch.Size([8388608, 1])
up_proj   => torch.Size([8388608, 1])
down_proj => torch.Size([8388608, 1])
X.shape => torch.Size([2, 3333, 2048])
[debug_forward] w_dec.shape => torch.Size([8192, 2048])
[INFO] final out.shape => torch.Size([2, 3333, 8192])


In [None]:
################################################################################
# 0) Install / Imports
################################################################################
try:
    import triton
except ImportError:
    !pip install -U triton

import torch
import torch.nn as nn
import time
import triton
import triton.language as tl
from triton import jit

# Attempt bitsandbytes/transformers
try:
    from bitsandbytes.nn import Linear4bit
    from transformers.activations import ACT2FN
except ImportError:
    print("[WARNING] bitsandbytes or transformers not installed; MLP test might not work fully.")

# For older GPUs, fallback BF16->FP16 if SM80+ is not available
major_cc, minor_cc = torch.cuda.get_device_capability()
BF16_AVAILABLE = (major_cc >= 8)
if not BF16_AVAILABLE:
    print("[INFO] This GPU does not support SM80+. We fallback from BF16 to FP16 where needed.")


################################################################################
# 1) Triton Kernel => Multi-Block NF4 Dequant
################################################################################

@triton.jit
def _nf4_dequantize_kernel(
    QWEIGHT, ABSMAX, CODE, OFFSET, S2_ABSMAX, S2_CODE, OUT,
    BLOCK_SIZE: tl.constexpr, LENGTH: tl.constexpr,
    USE_128: tl.int32, SIGNED_NIB: tl.int32
):
    """
    Multi-block NF4 kernel => each block processes up to BLOCK_SIZE elements.
    """
    pid = tl.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < LENGTH

    byte_idx = idx // 2
    nib_side = idx & 1

    # Load as uint8 for bitwise
    byte_v = tl.load(QWEIGHT + byte_idx, mask=mask, other=0).to(tl.uint8)
    nib_i32 = (byte_v >> (nib_side * 4)) & 0xF

    # Signed nib?
    if SIGNED_NIB != 0:
        is_neg = nib_i32 >= 8
        nib_signed = nib_i32 - 16
        nib_signed = tl.where(is_neg, nib_signed, nib_i32)
        code_idx   = nib_signed + 8
        code_idx   = tl.where(code_idx < 0, 0, code_idx)
        code_idx   = tl.where(code_idx > 15, 15, code_idx)
    else:
        code_idx = nib_i32
        code_idx = tl.where(code_idx > 15, 15, code_idx)

    # Load LUT => code, offset
    code_val = tl.load(CODE + code_idx,  mask=mask, other=0)
    off_val  = tl.load(OFFSET + nib_i32, mask=mask, other=0)
    val_f32  = code_val + off_val

    # Scale factors
    block_id = idx // 64
    am_u8    = tl.load(ABSMAX  + block_id, mask=mask, other=0)
    s2_am_u8 = tl.load(S2_ABSMAX+ block_id, mask=mask, other=127)

    scale = tl.where(USE_128 != 0, am_u8 / 128.0, am_u8 / 127.0)
    scale = tl.where(scale == 0, 1.0, scale)
    s2_am = tl.where(USE_128 != 0, s2_am_u8 / 128.0, s2_am_u8 / 127.0)

    s2_cd = tl.load(S2_CODE + nib_i32, mask=mask, other=0)
    val_f32 = val_f32 * s2_cd * scale * s2_am

    out_val = val_f32.to(tl.float16)
    tl.store(OUT + idx, out_val, mask=mask)


################################################################################
# 2) your_dequantize_nf4 => checks for 4-bit vs. already-full
################################################################################

def your_dequantize_nf4(layer):
    """
    Checks the shape of layer.weight => if numel == rows*cols, skip decode.
    If numel*2 == rows*cols, do 4-bit decode with the Triton kernel.
    If offset / state2 is missing, use dummy placeholders.
    """
    w = layer.weight
    if not hasattr(w, 'quant_state'):
        # fallback => no decode
        return w

    qs = w.quant_state
    rows, cols = w.shape
    final_numel = rows * cols  # how many FP16 elements we need for the final shape
    actual_numel = w.numel()   # how many elements bitsandbytes allocated

    # CASE 1: If bitsandbytes is storing it at "full" size => skip decode
    if actual_numel == final_numel:
        return w

    # CASE 2: If half size => do the NF4 decode
    if actual_numel * 2 != final_numel:
        # mismatch => raise an error or skip
        msg = (f"[ERROR] Expected w.numel()*2 == {final_numel}, but got w.numel()={actual_numel}. "
               f"Cannot decode properly. shape={w.shape}, final_numel={final_numel}")
        raise RuntimeError(msg)

    # Now we do the normal NF4 decode
    length_nibbles = final_numel  # = w.numel() * 2
    BLOCK_SIZE = 1024
    grid = ((length_nibbles + BLOCK_SIZE - 1) // BLOCK_SIZE,)

    out_1d = torch.empty(length_nibbles, dtype=torch.float16, device=w.device)

    # Check qs.offset
    offset_tensor = (
        qs.offset.to(w.device)
        if getattr(qs, 'offset', None) is not None
        else torch.zeros((16,), dtype=w.dtype, device=w.device)
    )

    # If qs.state2 is None or missing => dummy placeholders
    if getattr(qs, 'state2', None) is None:
        qs.state2 = type('', (), {})()
    if getattr(qs.state2, 'absmax', None) is None:
        qs.state2.absmax = torch.zeros_like(qs.absmax)
    if getattr(qs.state2, 'code', None) is None:
        qs.state2.code = torch.zeros_like(qs.code)

    absmax_s2 = qs.state2.absmax.to(w.device)
    code_s2   = qs.state2.code.to(w.device)

    # Launch kernel
    _nf4_dequantize_kernel[grid](
        w,
        qs.absmax.to(w.device),
        qs.code.to(w.device),
        offset_tensor,
        absmax_s2,
        code_s2,
        out_1d,
        BLOCK_SIZE=BLOCK_SIZE,
        LENGTH=length_nibbles,
        USE_128=0,
        SIGNED_NIB=0
    )

    return out_1d.reshape(rows, cols)


################################################################################
# 3) MLP => 3 x 4-bit Layers
################################################################################

try:
    from bitsandbytes.nn import Linear4bit
    from transformers.activations import ACT2FN
except ImportError:
    class Linear4bit(nn.Linear):
        pass
    def ACT2FN(x): return torch.nn.functional.silu

class MLP(nn.Module):
    def __init__(self, hd=4096, m=14336, dtype=torch.float16):
        super().__init__()
        # Example bitsandbytes usage => 4-bit NF4
        # Setting compress_statistics=False to avoid dynamic stats changes
        self.gate_proj = Linear4bit(hd, m, bias=None, compute_dtype=dtype,
                                    compress_statistics=False, quant_type="nf4").cuda()
        self.up_proj   = Linear4bit(hd, m, bias=None, compute_dtype=dtype,
                                    compress_statistics=False, quant_type="nf4").cuda()
        self.down_proj = Linear4bit(m, hd, bias=None, compute_dtype=dtype,
                                    compress_statistics=False, quant_type="nf4").cuda()

        # If bitsandbytes => set quant_state dtype
        for layer in (self.gate_proj, self.up_proj, self.down_proj):
            if hasattr(layer.weight, 'quant_state'):
                layer.weight.quant_state.dtype = dtype

        self.act_fn = torch.nn.functional.silu

    def forward(self, x):
        return self.down_proj(
            self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        )


################################################################################
# 4) Helper => decode each layer weight, then forward
################################################################################

def mlp_forward(X, mlp, decode_fx):
    """
    Calls decode_fx on each 4-bit layer, does forward with them, then reverts.
    """
    gw_decoded = decode_fx(mlp.gate_proj)
    up_decoded = decode_fx(mlp.up_proj)
    dn_decoded = decode_fx(mlp.down_proj)

    old_gw = mlp.gate_proj.weight.data
    old_up = mlp.up_proj.weight.data
    old_dn = mlp.down_proj.weight.data

    mlp.gate_proj.weight.data = gw_decoded.data
    mlp.up_proj.weight.data   = up_decoded.data
    mlp.down_proj.weight.data = dn_decoded.data

    out = mlp(X)

    # restore
    mlp.gate_proj.weight.data = old_gw
    mlp.up_proj.weight.data   = old_up
    mlp.down_proj.weight.data = old_dn

    return out


################################################################################
# 5) Testing => measure decode overhead
################################################################################

def test_dequantize(dequant_fx):
    shapes = [
        (2, 3333, 2048, 8192,  3407, torch.float16),
        (5,  777, 1024, 4096,  3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    total_time = 0
    for (bsz, qlen, hd, m, seed, dt) in shapes:
        if dt == torch.bfloat16 and not BF16_AVAILABLE:
            dt = torch.float16
        torch.manual_seed(seed)
        mlp = MLP(hd=hd, m=m, dtype=dt).cuda()
        X   = torch.randn((bsz, qlen, hd), device='cuda', dtype=dt)

        torch.cuda.synchronize()
        # warmup
        for _ in range(2):
            out_dec = mlp_forward(X, mlp, dequant_fx)
            out_ref = mlp(X)
            if not torch.allclose(out_dec.float(), out_ref.float(), atol=1e-2, rtol=1e-2):
                print("[WARNING] Mismatch in decode vs. ref (within 1e-2 tolerance).")

        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000):
            mlp_forward(X, mlp, dequant_fx)
        torch.cuda.synchronize()
        total_time += (time.time() - start)

    return total_time


################################################################################
# 6) Optional Fused Example
################################################################################
@triton.jit
def _fused_nf4_linear_kernel(
    A, W_4bit, LUT,
    OUT,
    B, K, N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    row_start = pid_m * BLOCK_M
    col_start = pid_n * BLOCK_N

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    BLOCK_K = 32
    for k_chunk in range(0, K, BLOCK_K):
        a_offset = (row_start + tl.arange(0, BLOCK_M))[:, None]*K + (k_chunk + tl.arange(0, BLOCK_K))[None, :]
        A_tile   = tl.load(A + a_offset, mask=(row_start+tl.arange(0,BLOCK_M))[:,None]<B, other=0.)

        w_rows = k_chunk + tl.arange(0, BLOCK_K)
        w_cols = col_start + tl.arange(0, BLOCK_N)
        w_cols_byte = w_cols // 2
        nib_side    = w_cols & 1
        offset = w_rows[:, None]*(N//2) + w_cols_byte[None, :]

        byte_v = tl.load(W_4bit + offset,
                         mask=(w_rows[:,None]<K)&(w_cols_byte[None,:]<(N//2)),
                         other=0).to(tl.uint8)
        nib_i32 = (byte_v >> (nib_side[None,:]*4)) & 0xF
        W_tile  = nib_i32.to(tl.float16)

        for kk in range(0, BLOCK_K):
            a_vec = A_tile[:, kk]
            w_vec = W_tile[kk, :]
            acc  += a_vec[:,None].to(tl.float32)*w_vec[None,:].to(tl.float32)

    acc_fp16 = acc.to(tl.float16)
    out_off  = (row_start+tl.arange(0,BLOCK_M))[:,None]*N + (col_start+tl.arange(0,BLOCK_N))[None,:]
    tl.store(OUT + out_off, acc_fp16,
             mask=(row_start+tl.arange(0,BLOCK_M))[:,None]<B)


def fused_nf4_linear(A, W_4bit, B, K, N):
    Out = torch.empty((B, N), device=A.device, dtype=torch.float16)
    LUT = torch.zeros((16,), device=A.device, dtype=torch.float16)
    BLOCK_M, BLOCK_N = 64, 64
    grid = ((B+BLOCK_M-1)//BLOCK_M, (N+BLOCK_N-1)//BLOCK_N)
    _fused_nf4_linear_kernel[grid](
        A, W_4bit, LUT, Out,
        B, K, N,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N
    )
    return Out


def test_fused_linear():
    B, K, N = 512, 1024, 4096
    A = torch.randn((B, K), device='cuda', dtype=torch.float16)
    W_4bit = torch.randint(0, 256, (K, N//2), device='cuda', dtype=torch.uint8)

    def decode_4bit(W_4bit):
        K_, halfN = W_4bit.shape
        N_ = halfN*2
        Out = torch.empty((K_, N_), dtype=torch.float16, device='cuda')
        idx = torch.arange(0, K_*N_, device='cuda')
        byte_idx = idx // 2
        nib_side = idx & 1
        data  = W_4bit.view(-1).gather(0, byte_idx)
        nibs  = (data >> (nib_side*4)) & 0xF
        Out.view(-1)[:] = nibs
        return Out

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        W_dec = decode_4bit(W_4bit)
        _ = A @ W_dec
    torch.cuda.synchronize()
    baseline_time = time.time() - start

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        _ = fused_nf4_linear(A, W_4bit, B, K, N)
    torch.cuda.synchronize()
    fused_time = time.time() - start

    print(f"[FUSED-LINEAR] baseline_time={baseline_time:.3f}s, fused_time={fused_time:.3f}s, "
          f"speedup={baseline_time/fused_time:.2f}x")


################################################################################
# 7) Main => Compare Times
################################################################################
if __name__ == "__main__":
    # Compare baseline => .to(fp16)
    print("---- Compare Unsloth vs. Triton decode in MLP forward ----")
    print("[INFO] Unsloth => .to(fp16)")
    time_unsloth = test_dequantize(lambda layer: layer.weight.data.to(torch.float16))
    print("[INFO] Unsloth total time =>", time_unsloth)

    # Compare Triton NF4 => your_dequantize_nf4
    print("[INFO] Triton NF4 => your_dequantize_nf4")
    time_ours = test_dequantize(your_dequantize_nf4)
    print("[INFO] Triton total time =>", time_ours)

    speedup = time_unsloth / time_ours if time_ours>0 else 9999
    print(f"[INFO] Speedup => {speedup:.2f}x (goal≥1.15x)\n")

    # Optional fused decode+matmul test
    print("---- OPTIONAL: Compare single fused decode+matmul vs baseline decode+matmul ----")
    test_fused_linear()
    print("[INFO] Done.")


[INFO] This GPU does not support SM80+. We fallback from BF16 to FP16 where needed.
---- Compare Unsloth vs. Triton decode in MLP forward ----
[INFO] Unsloth => .to(fp16)
[INFO] Unsloth total time => 149.93729329109192
[INFO] Triton NF4 => your_dequantize_nf4
[INFO] Triton total time => 146.93836498260498
[INFO] Speedup => 1.02x (goal≥1.15x)

---- OPTIONAL: Compare single fused decode+matmul vs baseline decode+matmul ----


CompilationError: at 16:79:
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    row_start = pid_m * BLOCK_M
    col_start = pid_n * BLOCK_N

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    BLOCK_K = 32
    for k_chunk in range(0, K, BLOCK_K):
        a_offset = (row_start + tl.arange(0, BLOCK_M))[:, None]*K + (k_chunk + tl.arange(0, BLOCK_K))[None, :]
                                                                               ^