In [None]:
# Install the quantization / fine-tuning stack. `--no-deps` tells pip to install
# only the named packages and skip their dependencies; xformers is the only
# package with a version pin.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
# Further packages, likewise installed without their dependencies.
!pip install --no-deps cut_cross_entropy unsloth_zoo
# These are installed with normal dependency resolution.
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
# unsloth itself is installed without dependencies, after the packages above.
!pip install --no-deps unsloth
!pip install transformers tf-keras

In [None]:
# Helper functions used throughout the notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Compute capability (major, minor) reported by torch for the current CUDA device.
major_version, minor_version = torch.cuda.get_device_capability()
# True when the major compute capability is 8 or higher.
HAS_BFLOAT16 = not (major_version < 8)
from inspect import currentframe as _C, getframeinfo

def _F(c):
    """Return the line number reported by `getframeinfo` for frame `c`."""
    return getframeinfo(c).lineno

def WARN(x):
    """Print `x` wrapped in the ANSI escape codes for red text."""
    return print(f"\033[31m{x}\033[0m")

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    """Check that tensor `x` has dtype `dtype` and is close to tensor `y`.

    Closeness is tested with `torch.testing.assert_close` using absolute and
    relative tolerances of 1e-1 and with stride checking enabled.

    Args:
        x: tensor under test.
        y: reference tensor.
        line: line number of the call site, used only in the error message.
        dtype: dtype that `x` is required to have.

    Raises:
        AssertionError: if `x.dtype` differs from `dtype`.
        RuntimeError: if the closeness check fails. The message names the
            caller's variables that hold `x` and `y` (empty when the argument
            was a temporary) and includes the underlying error text.
    """
    assert x.dtype == dtype, f"Expected dtype {dtype}, got {x.dtype}"
    try:
        torch.testing.assert_close(x, y, check_stride = True, atol = 1e-1, rtol = 1e-1)
    except Exception as error:
        # Resolve names against the CALLER's locals. Looking them up from inside
        # this function would always yield the parameter names "x" and "y".
        caller = inspect.currentframe().f_back
        caller_locals = caller.f_locals if caller is not None else {}

        def _name_of(tensor):
            for name, value in caller_locals.items():
                if value is tensor:
                    return name
            return ""

        raise RuntimeError(
            f"Failed allclose at line [{line}]: {_name_of(x)}, {_name_of(y)}\n{str(error)}"
        ) from error

# Sets the Hugging Face Hub switch for the hf_transfer download backend
# (hf_transfer is installed in the first cell).
# NOTE(review): this runs after `transformers` was imported above; if
# huggingface_hub reads the variable at import time the setting may arrive
# too late to take effect — confirm, and move it above the imports if so.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [None]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    """Dequantize a 4-bit layer with Unsloth's `fast_dequantize`.

    `weight` is the layer module; its `.weight` and that parameter's
    `quant_state` are passed to `fast_dequantize`.
    """
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Create a bitsandbytes `Linear4bit` layer with the notebook's fixed settings.

    `hd` and `m` are forwarded as the two positional size arguments. The layer
    has no bias, uses `dtype` as its compute dtype, enables
    `compress_statistics`, and uses the "nf4" quantization type.
    """
    layer_options = dict(
        bias                = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )
    return Linear4bit(hd, m, **layer_options)

# [NEW] as of 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """Assert that a 4-bit layer's quantization metadata has the expected layout.

    Checks, in order: the packed weight is uint8; the quant state's dtype is
    `dtype`; its absmax is uint8; its code and offset are float32; its
    blocksize is 64; and the nested `state2` has float32 absmax and code with
    a blocksize of 256. Raises AssertionError on the first check that fails.
    """
    packed = weight.weight
    assert packed.dtype == torch.uint8

    state = packed.quant_state
    assert state.dtype == dtype
    assert state.absmax.dtype == torch.uint8
    assert state.code.dtype == torch.float32
    assert state.offset.dtype == torch.float32
    assert state.blocksize == 64

    nested = state.state2
    assert nested.absmax.dtype == torch.float32
    assert nested.code.dtype == torch.float32
    assert nested.blocksize == 256

class MLP(nn.Module):
    """Gated MLP made of three 4-bit `bnb_Linear4bit` projections.

    `gate_proj` and `up_proj` are built with sizes (hd, m) and `down_proj`
    with (m, hd); all three are moved to "cuda". The forward pass computes
    `down_proj(act_fn(gate_proj(x)) * up_proj(x))` with the "silu" activation.
    """

    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        # Layers are created in the order gate, up, down.
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as of 18th Feb 2025: record the requested dtype on every quant state.
        for projection in (self.gate_proj, self.up_proj, self.down_proj):
            projection.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        gated  = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)

def mlp_forward(X, mlp, fx):
    """Run the gated-MLP forward pass with explicitly dequantized weights.

    `fx` maps each projection of `mlp` to a weight tensor, which is transposed
    and right-multiplied onto the activations. The up projection is evaluated
    before the gate projection.
    """
    up_weight = fx(mlp.up_proj)
    up_out = X @ up_weight.t()
    gate_weight = fx(mlp.gate_proj)
    gate_out = X @ gate_weight.t()
    hidden = mlp.act_fn(gate_out) * up_out
    return hidden @ fx(mlp.down_proj).t()

def mlp_dequantize(X, mlp, fx):
    """Dequantize the up, gate and down projections of `mlp` with `fx`.

    Each result is transposed, and `torch.cuda.synchronize()` is called after
    every projection. `X` is accepted but not used. Returns the three
    transposed tensors as a tuple in the order (up, gate, down).
    """
    transposed = []
    for attr in ("up_proj", "gate_proj", "down_proj"):
        transposed.append(fx(getattr(mlp, attr)).t())
        torch.cuda.synchronize()
    return tuple(transposed)

def test_dequantize(dequantize_fx, n_warmup = 2, n_iters = 1000):
    """Validate and time a dequantization function on three MLP configurations.

    For each configuration a freshly seeded `MLP` and random input are built.
    The warmup iterations check that
      * `mlp_forward` with `dequantize_fx` matches the module's own forward,
      * the bitsandbytes quant-state layout is as expected, and
      * every dequantized projection matches `unsloth_dequantize`.
    Afterwards `mlp_dequantize` is timed over `n_iters` calls.

    Args:
        dequantize_fx: callable taking a 4-bit layer and returning its
            dequantized weight.
        n_warmup: number of warmup / correctness iterations per configuration.
        n_iters: number of timed `mlp_dequantize` calls per configuration.

    Returns:
        Total elapsed seconds of the timed loops, summed over all configurations.

    Raises:
        AssertionError / RuntimeError: from the correctness checks.
    """
    elapsed = 0.0
    # (batch size, sequence length, hidden dim, intermediate dim, seed, dtype)
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup (doubles as the correctness check)
        for _ in range(n_warmup):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as of 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking. perf_counter is a monotonic, high-resolution clock, so the
        # measurement cannot be skewed by wall-clock adjustments. mlp_dequantize
        # synchronizes after every projection, so the GPU work is finished when
        # the clock is read.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.perf_counter() - start
    return elapsed

In [4]:
# Baseline timing for Unsloth's fast_dequantize. `unsloth_dequantize` is already
# defined (and `fast_dequantize` imported) in the setup cell above, so neither
# is repeated here.
test_dequantize(unsloth_dequantize)

1.5299334526062012

In [5]:
# Baseline timing for PEFT's dequantize_module_weight. `peft_dequantize` is
# already imported in the setup cell above, so the import is not repeated here.
test_dequantize(peft_dequantize)

1.76436448097229

In [None]:
from triton import jit
import triton
import triton.language as tl

# Triton kernel that unpacks 4-bit codes into scaled values.
#
# Each program (indexed by `pid`) writes one tile of `Br` output values, laid
# out as `blocks` rows of `Br // blocks` values, where every row shares one
# scale factor. Programs with `pid >= gsize` do nothing.
#
# Arguments:
#   w_ptr       packed input; every loaded element yields two 4-bit codes
#   w_out       output buffer, written at `pid*Br + ...`
#   abs_ptrs    per-row first-level absmax values, used as indices into `code2`
#   offset_ptr  pointer to a single offset added to every scale
#   abs2_ptrs   second-level absmax values, one per `block_size2` rows
#   code2       lookup table indexed by the first-level absmax values
#   block_size2 number of rows that share one second-level absmax
#   gsize       number of programs that have real work
#   code        lookup table indexed by the 4-bit codes
#   blocks      rows per tile (compile-time constant)
#   Br          output values per tile (compile-time constant)
#
# NOTE(review): none of the loads or the store is masked, so this appears to
# assume the data is a whole number of `Br`-sized tiles — confirm with the launcher.
@triton.jit
def _your_dequantize_nf4_kernel(w_ptr, w_out, abs_ptrs,
                                offset_ptr, 
                                abs2_ptrs, code2,
                                block_size2,gsize,code,blocks:tl.constexpr, Br:tl.constexpr):
    

    pid = tl.program_id(0)
    
    # The launch grid may contain more programs than tiles; the extras exit here.
    if pid < gsize:

        # Indices of the `blocks` consecutive rows (scale entries) owned by this program.
        absmax_group = pid*blocks + tl.arange(0,blocks) # can be coalesced
        absmax = (tl.load(abs_ptrs + absmax_group,  eviction_policy= "evict_first"))

        # Inline assembly: count the leading zeros of block_size2.
        lz = tl.inline_asm_elementwise(
        # 31 - clz(block_size2) is floor(log2(block_size2)), used as a shift below
        asm="""
        {
            clz.b32 $0, $1;
        }
        """,
        constraints=(
            "=r,r"),
        args=[block_size2],
        dtype=(tl.int32),
        is_pure=True,
        pack=1,
        )

        # Shift right by log2(block_size2): equals absmax_group // block_size2
        # provided block_size2 is a power of two (assumed — TODO confirm).
        absmax_group2 = (absmax_group)>>(31-lz)
        # Decode the first-level absmax through the code2 table ...
        real_absmax = tl.load(code2 + absmax,  eviction_policy= "evict_last")
        # ... then rescale with the second-level absmax and add the offset.
        absmax2 = tl.load(abs2_ptrs + absmax_group2,  eviction_policy= "evict_last")
        offset = tl.load(offset_ptr,  eviction_policy= "evict_last") 
        final_absmax = absmax2 * real_absmax + offset

        # Offsets of the packed input: (blocks, Br // (2*blocks)) elements per program.
        w_off = pid*(Br//2) + tl.arange(0, blocks)[:, None]*(Br//(2*blocks)) + tl.arange(0, Br//(2*blocks))[None, :]

        w_packed = tl.load(w_ptr + w_off,  eviction_policy= "evict_first")
        # Interleaving the tile with itself doubles the last dimension, so each
        # packed element appears twice in a row: once per 4-bit code it holds.
        w_packed2 = tl.interleave(w_packed,w_packed)
        

        # Even positions select the upper 4 bits (shift 4), odd positions the lower 4 bits.
        shift_sh = tl.arange(0, blocks)[:, None]*(Br//(blocks)) + tl.arange(0, Br//(blocks))[None, :]
        shift = tl.where(shift_sh % 2 == 0, 4, 0)

        shifted_w = (w_packed2 >> shift) & 0xF
        
        # Map every 4-bit code to its value in the `code` table.
        real_w = tl.load(shifted_w + code, eviction_policy= "evict_last")

        # Apply the per-row scale.
        scaled_w = (real_w* final_absmax[:, None])

        # Output offsets: (blocks, Br // blocks) values per program.
        out_off = pid*Br + tl.arange(0, blocks)[:, None]*(Br//blocks) + tl.arange(0, Br//blocks)[None, :]
        
        tl.store(w_out + out_off, scaled_w, eviction_policy= "evict_first")
    return

def _launch_grid_size(n_tiles, props):
    """Round `n_tiles` up to a whole number of per-device "waves" of programs.

    The per-SM factors are the tuned values for the architectures this kernel
    was written for. On any other architecture the grid is left unpadded, so
    exactly `n_tiles` programs are launched.
    """
    if props.major == 8 and props.minor == 9:    # Ada
        per_sm = 24
    elif props.major == 8 and props.minor == 0:  # Ampere
        per_sm = 32
    elif props.major == 7:                       # Turing
        per_sm = 16
    else:
        return n_tiles
    wave = per_sm * props.multi_processor_count
    # cdiv * wave equals n_tiles itself when n_tiles is already a multiple of wave.
    return triton.cdiv(n_tiles, wave) * wave

def _your_dequantize_nf4(weight, quant_state):
    """Dequantize a packed 4-bit weight with `_your_dequantize_nf4_kernel`.

    Args:
        weight: packed weight storage; every element holds two 4-bit codes.
        quant_state: quantization state providing `dtype`, `shape`, `code`,
            `absmax`, `offset`, `blocksize` and a nested `state2` with
            `absmax`, `code` and `blocksize`.

    Returns:
        A tensor of shape `quant_state.shape` on the same device as `weight`.
        Its dtype is float16 when `quant_state.dtype` is float16 and bfloat16
        otherwise.

    Raises:
        ValueError: if the number of unpacked values is not a multiple of the
            kernel tile size. The kernel performs unmasked loads and stores,
            so a partial tile would access memory out of bounds.
    """
    device = weight.device
    out_dtype = quant_state.dtype
    nested = quant_state.state2

    Br = 8192                               # output values handled per program
    blocks = Br // quant_state.blocksize    # quantization blocks per program
    out_size = weight.numel() * 2           # two 4-bit values per packed element

    if out_size % Br != 0:
        raise ValueError(
            f"Unpacked size {out_size} must be a multiple of the tile size {Br}"
        )
    gsize = out_size // Br

    w_out = torch.empty(
        quant_state.shape,
        device = device,
        dtype = out_dtype if out_dtype == torch.float16 else torch.bfloat16,
        requires_grad = False,
    )
    if gsize == 0:
        return w_out  # nothing to unpack

    props = torch.cuda.get_device_properties(device)
    grid = (_launch_grid_size(gsize, props),)

    # Launch on the device that owns `weight`, even if it is not the current one.
    with torch.cuda.device(device):
        _your_dequantize_nf4_kernel[grid](
            weight, w_out,
            quant_state.absmax, quant_state.offset,
            nested.absmax, nested.code,
            nested.blocksize, gsize, quant_state.code, blocks, Br,
            num_warps = 16, num_stages = 1, maxnreg = 8192,
        )
    return w_out

def your_dequantize_nf4(weight):
    """Dequantize a 4-bit layer with the Triton implementation.

    `weight` is the layer module; the `.data` of its `.weight` and that
    parameter's `quant_state` are handed to `_your_dequantize_nf4`.
    """
    packed_param = weight.weight
    return _your_dequantize_nf4(packed_param.data, packed_param.quant_state)

In [12]:
### SPEEDUP of the Triton kernel over Unsloth's fast_dequantize (hopefully 1.15x or more).
# Every test_dequantize call runs its correctness checks before timing; to only
# validate the kernel, run test_dequantize(your_dequantize_nf4) on its own.
unsloth_seconds = test_dequantize(unsloth_dequantize)
triton_seconds  = test_dequantize(your_dequantize_nf4)
unsloth_seconds / triton_seconds

1.2290114722981054