In [1]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps "unsloth>=2025.3.8" "unsloth_zoo>=2025.3.7" --upgrade --force-reinstall

In [5]:
# NO MORE INVALID/UNAVAILABLE TRITON :)

import triton
import triton.language as tl

# List of Triton language functions and constructs used in the script.
functions_to_check = [
    "inline_asm_elementwise",  # used for inline PTX assembly
    "load",                    # for loading from memory
    "store",                   # for storing to memory
    "arange",                  # for generating index arrays
    "program_id",              # for obtaining the block ID
    "sum",                     # for summing an array of values
    "cast",                    # for type conversions
    "gather",                  # for gathering from a lookup table
    "constexpr",
    "alloc",
    "bitcast"
]

print("Checking functions in triton.language (tl):")
for fn in functions_to_check:
    if hasattr(tl, fn):
        print(f"  {fn} is available.")
    else:
        print(f"  {fn} is NOT available.")

# Also check for triton.jit in the triton module.
print("\nChecking for triton.jit in the triton module:")
if hasattr(triton, "jit"):
    print("  triton.jit is available.")
else:
    print("  triton.jit is NOT available.")


Checking functions in triton.language (tl):
  inline_asm_elementwise is available.
  load is available.
  store is available.
  arange is available.
  program_id is available.
  sum is available.
  cast is available.
  gather is NOT available.
  constexpr is available.
  alloc is NOT available.
  bitcast is NOT available.

Checking for triton.jit in the triton module:
  triton.jit is available.


In [13]:
#%%writefile occupant_decode_auto_tuner.py
import os
import json
import torch
import torch.nn as nn
import torch._dynamo
import time
import unsloth
import unsloth.kernels.utils as uutils
from transformers import set_seed
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

###############################################################################
# 0) Fallback + unify logic
###############################################################################

orig_fx = uutils.fast_dequantize

def unify_4bit_weight(bnb_weight, old_shape, new_out, new_in):
    import bitsandbytes.functional as bnbF
    qs = bnb_weight.quant_state
    float_data = bnbF.dequantize_4bit(bnb_weight.data, qs).float()
    want_elems = new_out * new_in
    if float_data.numel() != want_elems:
        print(f"[unify_4bit_weight] mismatch => {float_data.numel()} vs {want_elems}. Skipping.")
        return bnb_weight
    float_data = float_data.reshape(new_out, new_in)
    q_w, q_state = bnbF.quantize_4bit(float_data, quant_type="nf4")
    bnb_weight.data = q_w
    bnb_weight.quant_state = q_state
    print(f"[unify_4bit_weight] from shape={old_shape} => ({new_out}x{new_in})")
    return bnb_weight

def unify_all_4bit_shapes(module):
    import bitsandbytes as bnb
    for name, child in module.named_modules():
        if isinstance(child, bnb.nn.Linear4bit):
            w = child.weight
            old_shape = tuple(w.shape)
            e = w.numel()
            guess_cols = [1024, 2048, 4096, 14336, 8192]
            found = False
            for c_ in guess_cols:
                if (e % c_) == 0:
                    r_ = e // c_
                    unify_4bit_weight(w, old_shape, r_, c_)
                    found = True
                    break
            if not found:
                print(f"[WARNING unify_all_4bit_shapes] Could not unify {old_shape}.")

###############################################################################
# 1) Triton Warp-Persistent Kernel
###############################################################################
import triton
import triton.language as tl

# Use the documented name for inline assembly.
if hasattr(tl, "inline_asm_elementwise"):
    @triton.jit
    def inline_ptx_prefetch(ptr):
        """
        Attempt L2 prefetch using inline assembly.
        """
        return tl.inline_asm_elementwise(
            asm_string="prefetch.global.L2 [$0];",
            constraints="r",
            pure=False,
            packed_element=1,
            args=[ptr]
        )
else:
    @triton.jit
    def inline_ptx_prefetch(ptr):
        _ = tl.load(ptr, mask=True, other=0)
        return 0

# Implement a custom gather function using tl.load.
# This function simulates gathering one element from a vector "vec" given an index "idx".
@triton.jit
def custom_gather(vec, idx):
    # Create an index array for 0..15
    indices = tl.arange(0, 16)
    # Compare each element with idx; mask will be 1.0 where equal and 0.0 otherwise.
    mask = tl.eq(indices, idx)
    # Multiply elementwise and sum over the 16 entries.
    return tl.sum(vec * tl.cast(mask, vec.dtype))

@triton.jit
def warp_persistent_decode_kernel(
    W_PTR, CODE_PTR, ABS_PTR, ABS2_PTR, OFFSET,
    Out_PTR,
    EVICT_PTR, EVICT_SIZE,
    total_elems: tl.int32,
    out_dtype_flag: tl.int32,
    use_inline_asm_flag: tl.int32,
    BLOCKS_PER_TB: tl.constexpr
):
    """
    Each block decodes BLOCKS_PER_TB blocks-of-64 nibbles in a loop => warp-persistent approach.
    Uses partial unroll (16 iterations for 64 nibbles) and sums eviction buffer values for an offset.
    """
    tb_id = tl.program_id(0)

    idx = tb_id * 64 + tl.arange(0, 64)
    mask_evct = idx < EVICT_SIZE
    ev_vals  = tl.load(EVICT_PTR + idx, mask=mask_evct, other=0)
    accum    = tl.sum(ev_vals)
    offset_val = OFFSET + accum * 1e-5

    if use_inline_asm_flag:
        inline_ptx_prefetch(W_PTR)

    nblocks = total_elems // 64
    start_block = tb_id * BLOCKS_PER_TB

    for i in range(BLOCKS_PER_TB):
        block_id = start_block + i
        cond_blk = block_id < nblocks
        nib_start = block_id * 64
        cond_elem = nib_start < total_elems
        cond = cond_blk & cond_elem

        ab_u8 = tl.load(ABS_PTR + block_id, mask=cond_blk, other=0)
        sc_f32 = tl.cast(ab_u8, tl.float32) / 127.
        ab2_f32 = tl.load(ABS2_PTR + block_id, mask=cond_blk, other=0)
        final_scale = sc_f32 * ab2_f32 + offset_val

        byte_st = nib_start // 2
        idx_b = byte_st + tl.arange(0, 32)
        cond_b = (idx_b < ((total_elems + 1) // 2)) & cond
        chunk = tl.load(W_PTR + idx_b, mask=cond_b, other=0)

        code_lut = tl.load(CODE_PTR + tl.arange(0, 16))
        # Decode 64 nibbles with partial unroll: 16 iterations, 4 nibbles each.
        for j in range(16):
            bA = chunk[j * 2]
            bB = chunk[j * 2 + 1]
            nibA_lo = bA & 0xF
            nibA_hi = bA >> 4
            nibB_lo = bB & 0xF
            nibB_hi = bB >> 4

            # Replace tl.gather with custom_gather
            valA_lo = custom_gather(code_lut, nibA_lo) * final_scale
            valA_hi = custom_gather(code_lut, nibA_hi) * final_scale
            valB_lo = custom_gather(code_lut, nibB_lo) * final_scale
            valB_hi = custom_gather(code_lut, nibB_hi) * final_scale

            out_idx = nib_start + j * 4
            cLoA = cond & (out_idx < total_elems)
            cHiA = cond & (out_idx + 1 < total_elems)
            cLoB = cond & (out_idx + 2 < total_elems)
            cHiB = cond & (out_idx + 3 < total_elems)

            if out_dtype_flag == 0:
                valA_lo = tl.cast(valA_lo, tl.float16)
                valA_hi = tl.cast(valA_hi, tl.float16)
                valB_lo = tl.cast(valB_lo, tl.float16)
                valB_hi = tl.cast(valB_hi, tl.float16)
            else:
                valA_lo = tl.cast(valA_lo, tl.bfloat16)
                valA_hi = tl.cast(valA_hi, tl.bfloat16)
                valB_lo = tl.cast(valB_lo, tl.bfloat16)
                valB_hi = tl.cast(valB_hi, tl.bfloat16)

            tl.store(Out_PTR + out_idx,   valA_lo, mask=cLoA)
            tl.store(Out_PTR + out_idx + 1, valA_hi, mask=cHiA)
            tl.store(Out_PTR + out_idx + 2, valB_lo, mask=cLoB)
            tl.store(Out_PTR + out_idx + 3, valB_hi, mask=cHiB)

###############################################################################
# 2) The occupant decode approach => separate function that calls the kernel
###############################################################################
@torch._dynamo.disable
def occupant_decode_warp_persistent_impl(weight, quant_state=None, out=None,
                                         use_inline_asm=True, BLOCKS_PER_TB=8):
    shape = getattr(quant_state, "shape", None)
    dt = weight.dtype
    n_elems = shape[0] * shape[1]
    if out is None:
        out = torch.empty(shape, dtype=dt, device=weight.device, requires_grad=False)

    nblocks = n_elems // 64
    gridX = (nblocks + BLOCKS_PER_TB - 1) // BLOCKS_PER_TB
    out_dtype_flag = 0 if dt == torch.float16 else 1
    asm_flag = 1 if use_inline_asm else 0

    evict_size = 4096
    evict_buf = torch.randint(0, 999, (evict_size,), dtype=torch.int32, device=weight.device)

    warp_persistent_decode_kernel[(gridX,)](
        weight,
        quant_state.state2.code,
        quant_state.absmax,
        quant_state.state2.absmax,
        float(quant_state.offset),
        out,
        evict_buf,
        evict_size,
        n_elems,
        out_dtype_flag,
        asm_flag,
        BLOCKS_PER_TB=BLOCKS_PER_TB
    )
    if shape[0] == 1:
        return out.t()
    return out

###############################################################################
# 3) The shape tuner cache => we store globally + load/save to disk
###############################################################################
_SHAPE_TUNER_CACHE = {}

def load_shape_tuner_cache(cache_path="shape_tuner_cache.json"):
    global _SHAPE_TUNER_CACHE
    if os.path.exists(cache_path):
        with open(cache_path, "r") as f:
            data = json.load(f)
        # Convert string keys back to tuples.
        new_cache = {}
        for shape_str, val in data.items():
            shape_tuple = eval(shape_str)
            new_cache[shape_tuple] = val
        _SHAPE_TUNER_CACHE = new_cache
        print(f"[load_shape_tuner_cache] loaded => {len(_SHAPE_TUNER_CACHE)} entries")
    else:
        print("[load_shape_tuner_cache] no file => starting empty")

def save_shape_tuner_cache(cache_path="shape_tuner_cache.json"):
    global _SHAPE_TUNER_CACHE
    sdict = {}
    for shape_tuple, best_BPTB in _SHAPE_TUNER_CACHE.items():
        sdict[str(shape_tuple)] = best_BPTB
    with open(cache_path, "w") as f:
        json.dump(sdict, f)
    print(f"[save_shape_tuner_cache] => wrote file with {len(sdict)} entries")

###############################################################################
# 4) occupant_decode_warp_auto => tries [4,8,16,32], picks best, caches
###############################################################################
@torch._dynamo.disable
def occupant_decode_warp_auto(weight, quant_state=None, out=None, use_inline_asm=True):
    """
    Checks if shape is in _SHAPE_TUNER_CACHE. If yes, calls occupant_decode_warp_persistent_impl;
    if no, tries multiple BLOCKS_PER_TB values, picks the fastest, caches it, and then calls the kernel.
    """
    if quant_state is None:
        quant_state = weight.quant_state
    shape = getattr(quant_state, "shape", None)
    if shape is None or quant_state.blocksize != 64 or tuple(shape) != tuple(weight.shape):
        return orig_fx(weight, quant_state)

    dt = weight.dtype
    n_elems = shape[0] * shape[1]

    global _SHAPE_TUNER_CACHE
    if shape in _SHAPE_TUNER_CACHE:
        best_BPTB = _SHAPE_TUNER_CACHE[shape]
        return occupant_decode_warp_persistent_impl(weight, quant_state, out, use_inline_asm, best_BPTB)
    else:
        # Auto-tune: try [4,8,16,32]
        test_params = [4, 8, 16, 32]
        best_time = 999999
        best_param = 8
        if out is None:
            out_buf = torch.empty(shape, dtype=dt, device=weight.device, requires_grad=False)
        else:
            out_buf = out

        for param_ in test_params:
            t0 = time.time()
            occupant_decode_warp_persistent_impl(weight, quant_state, out_buf, use_inline_asm, param_)
            torch.cuda.synchronize()
            dur = time.time() - t0
            if dur < best_time:
                best_time = dur
                best_param = param_

        _SHAPE_TUNER_CACHE[shape] = best_param
        return occupant_decode_warp_persistent_impl(weight, quant_state, out_buf, use_inline_asm, best_param)

###############################################################################
# 5) Test harness => occupant_decode_warp_auto
###############################################################################
def test_dequantize_function(fxdequant):
    combos = [
        (2, 3333, 2048, 8192, 3407, torch.float16),
        (5,  777, 1024, 4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    total = 0
    for (bsz, qlen, hd, m, seed, dt) in combos:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)

        class MLPTest(nn.Module):
            def __init__(self, hd, m, dt):
                super().__init__()
                self.gate_proj = Linear4bit(hd, m, bias=None, compute_dtype=dt, quant_type="nf4").cuda()
                self.up_proj   = Linear4bit(hd, m, bias=None, compute_dtype=dt, quant_type="nf4").cuda()
                self.down_proj = Linear4bit(m, hd, bias=None, compute_dtype=dt, quant_type="nf4").cuda()
                # Force the quant_state dtype
                self.gate_proj.weight.quant_state.dtype = dt
                self.up_proj.weight.quant_state.dtype   = dt
                self.down_proj.weight.quant_state.dtype = dt
                self.act_fn = nn.SiLU()

            def forward(self, x):
                return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        net = MLPTest(hd, m, dt)
        X = torch.randn((bsz, qlen, hd), device='cuda', dtype=dt)

        def do_dequant():
            g_ = fxdequant(net.gate_proj.weight).t()
            u_ = fxdequant(net.up_proj.weight).t()
            d_ = fxdequant(net.down_proj.weight).t()
            return (g_, u_, d_)

        # Warm-up
        for _ in range(2):
            do_dequant()

        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(1000):
            do_dequant()
        torch.cuda.synchronize()
        total += (time.time() - t0)
    return total

def test_final_method():
    # Patch occupant => occupant_decode_warp_auto
    old_fx = uutils.fast_dequantize
    uutils.fast_dequantize = occupant_decode_warp_auto

    print("\n=== Testing occupant decode warp auto (with shape tuner) ===")
    new_time = test_dequantize_function(occupant_decode_warp_auto)
    print(f"[Occupant decode auto] => {new_time:.4f}s")

    # Revert
    uutils.fast_dequantize = old_fx
    def baseline_dequant(w):
        return orig_fx(w, w.quant_state)

    print("\n=== Testing original unsloth => baseline ===")
    old_time = test_dequantize_function(baseline_dequant)
    print(f"[Unsloth Original] => {old_time:.4f}s")
    speedup = old_time / new_time
    print(f"Speedup => {speedup:.2f}x")

def test_torch_compile():
    """
    Occupant decode is done outside compile => skip nib errors.
    """
    try:
        matmul_compiled = torch.compile(lambda x, w: x @ w.t(), fullgraph=True, dynamic=True)
    except Exception as e:
        matmul_compiled = None
        print("[WARNING] torch.compile not supported =>", e)

    if matmul_compiled is None:
        print("[WARNING] skipping test_torch_compile => none")
        return

    print("\n--- occupant decode warp auto => decode outside => compiled matmul ---")
    from bitsandbytes.nn import Linear4bit
    dt = torch.float16
    l4b = Linear4bit(128, 256, bias=None, compute_dtype=dt, quant_type="nf4").cuda()
    # Decode using occupant auto => tries multiple parameters => caches best.
    w_dec = occupant_decode_warp_auto(l4b.weight)
    x = torch.randn((2, 128), dtype=dt, device='cuda')
    if x.dtype != w_dec.dtype:
        x = x.to(w_dec.dtype)

    out = matmul_compiled(x, w_dec)
    for _ in range(10):
        out = matmul_compiled(x, w_dec)
    print("compiled_matmul => out.shape =>", out.shape)

###############################################################################
# 6) Shape tuner load/save => do at start & end
###############################################################################
def main():
    # Load tuner from disk.
    load_shape_tuner_cache("shape_tuner_cache.json")

    test_final_method()
    test_torch_compile()

    # Save tuner cache to disk.
    save_shape_tuner_cache("shape_tuner_cache.json")

 # if __name__ == "__main__":
 #    main()
 #    import time
 #    time.sleep(3)


In [10]:
test_final_method()


=== Testing occupant decode warp auto (with shape tuner) ===
[Occupant decode auto] => 5.1223s

=== Testing original unsloth => baseline ===
[Unsloth Original] => 5.5637s
Speedup => 1.09x


In [16]:
test_final_method()


=== Testing occupant decode warp auto (with shape tuner) ===
[Occupant decode auto] => 5.2976s

=== Testing original unsloth => baseline ===
[Unsloth Original] => 5.3428s
Speedup => 1.01x
