In [None]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [1]:
# Helpful functions used through the entire notebook
import os

# Opt in to hf_transfer for Hub downloads. huggingface_hub reads this flag once, when it
# is first imported (importing transformers pulls it in), so it must be set before that.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import inspect
import time
from inspect import currentframe as _C, getframeinfo

import torch
import torch.nn as nn
from transformers import set_seed

# bfloat16 is treated as available on compute capability >= 8 (Ampere and newer).
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)

def _F(c):
    """Return the line number currently executing in frame `c` (call as `_F(_C())`)."""
    return getframeinfo(c).lineno

def WARN(x):
    """Print `x` in red using ANSI escape codes (red colored warnings)."""
    print(f"\033[31m{x}\033[0m")

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    """Return the caller's local variable name bound to `var`, or "" if there is none.

    Scans the locals of the frame that called NAME and returns the first name
    whose value is `var` itself (identity, not equality).
    """
    frame = inspect.currentframe().f_back
    for candidate, value in frame.f_locals.items():
        if value is var:
            return candidate
    return ""

### WARNING: MODIFIED RTOL & ATOL
def assert_same(x, y, line, dtype, atol = 1e-4, rtol = 1e-3):
    """Assert that `x` has dtype `dtype` and matches `y` in values and strides.

    Args:
        x, y: tensors to compare (`y` is the reference).
        line: caller's line number, included in the failure message.
        dtype: expected dtype of `x`.
        atol, rtol: tolerances forwarded to `torch.testing.assert_close`;
            the defaults are the loosened values this notebook always used.

    Raises:
        AssertionError: if `x.dtype != dtype`.
        RuntimeError: if values or strides differ. The message names the
            caller's variables bound to `x` and `y` ("" for temporaries) and
            the original assert_close error is chained as the cause.
    """
    assert x.dtype == dtype, f"dtype mismatch at line [{line}]: expected {dtype}, got {x.dtype}"
    try:
        torch.testing.assert_close(x, y, check_stride = True, atol = atol, rtol = rtol)
    except Exception as error:
        # Resolve names in the frame that called assert_same. Calling NAME() here
        # would inspect *this* frame and therefore always report "x, y".
        caller_locals = inspect.currentframe().f_back.f_locals
        def _name(var):
            return next((k for k, v in caller_locals.items() if v is var), "")
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {_name(x)}, {_name(y)}\n{str(error)}"
        ) from error

# Opt in to hf_transfer for faster Hugging Face Hub downloads.
# NOTE(review): huggingface_hub reads this flag when it is first imported, so it only
# takes effect if set before huggingface_hub (pulled in by transformers) is imported.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})

In [2]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

def unsloth_dequantize(weight):
    """Dequantize a Linear4bit module's packed weight with unsloth's `fast_dequantize`."""
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Build a bias-free NF4 `Linear4bit(hd -> m)` with compressed (nested) statistics.

    `dtype` becomes the layer's compute dtype. The layer's quant_state is only
    inspected by callers after they move it to CUDA (see MLP).
    """
    layer_kwargs = dict(
        bias = None,  # falsy -> no bias parameter
        compute_dtype = dtype,
        compress_statistics = True,
        quant_type = "nf4",
    )
    return Linear4bit(hd, m, **layer_kwargs)

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """Assert a Linear4bit module's packed weight and quant state have the expected layout.

    Checks: packed weight is uint8; quant_state.dtype equals `dtype`; absmax is
    uint8 with blocksize 64; code/offset are float32; the nested state2 has
    float32 absmax/code and blocksize 256. Raises AssertionError on the first mismatch.
    """
    packed = weight.weight
    assert packed.dtype == torch.uint8

    state = packed.quant_state
    assert state.dtype == dtype
    assert state.absmax.dtype == torch.uint8
    assert state.code.dtype == torch.float32
    assert state.offset.dtype == torch.float32
    assert state.blocksize == 64

    nested = state.state2
    assert nested.absmax.dtype == torch.float32
    assert nested.code.dtype == torch.float32
    assert nested.blocksize == 256

class MLP(nn.Module):
    """Gated MLP built from three NF4 Linear4bit layers placed on CUDA.

    forward(x) = down_proj(act_fn(gate_proj(x)) * up_proj(x)), with act_fn = SiLU.
    Each layer's quant_state.dtype is forced to `dtype` ([NEW] as at 18th Feb 2025).
    """
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()

        def quantized_linear(in_features, out_features):
            # Create the layer, move it to CUDA, then pin the quant state's
            # dtype to the requested compute dtype.
            layer = bnb_Linear4bit(in_features, out_features, dtype = dtype).to("cuda")
            layer.weight.quant_state.dtype = dtype
            return layer

        self.gate_proj = quantized_linear(hd, m)
        self.up_proj   = quantized_linear(hd, m)
        self.down_proj = quantized_linear(m, hd)
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    """Reference gated-MLP forward using weights produced by `fx`.

    Computes (act_fn(X @ Wg.T) * (X @ Wu.T)) @ Wd.T where Wu, Wg, Wd are
    fx(mlp.up_proj), fx(mlp.gate_proj), fx(mlp.down_proj) respectively.
    """
    up_out = X @ fx(mlp.up_proj).t()
    gate_out = X @ fx(mlp.gate_proj).t()
    return (mlp.act_fn(gate_out) * up_out) @ fx(mlp.down_proj).t()

def mlp_dequantize(X, mlp, fx):
    """Dequantize the up, gate and down projections of `mlp` with `fx`.

    Returns the three transposed weights as a tuple (up, gate, down). The CUDA
    device is synchronized after each one so wall-clock timing covers the work.
    `X` is accepted for signature parity with `mlp_forward` but is not used.
    """
    transposed = []
    for projection in (mlp.up_proj, mlp.gate_proj, mlp.down_proj):
        transposed.append(fx(projection).t())
        torch.cuda.synchronize()
    return tuple(transposed)

def test_dequantize(dequantize_fx):
    """Validate `dequantize_fx` and return its total benchmark time in seconds.

    For each configuration (bsz, qlen, hd, m, seed, dtype) a freshly seeded
    4-bit `MLP` is built, then:

    * Warmup / correctness (2 rounds; the first call also absorbs any one-off
      JIT compile cost): `mlp_forward` with `dequantize_fx` must match the
      bitsandbytes forward `mlp(X)`, each projection must pass
      `assert_correct_bnb`, and every dequantized weight must match
      `unsloth_dequantize`.
    * Benchmark: 1000 calls of `mlp_dequantize`, timed with a monotonic clock.

    Args:
        dequantize_fx: callable taking a Linear4bit module and returning its
            dequantized weight in the same layout as `unsloth_dequantize`.

    Returns:
        Seconds spent in the benchmark loops, summed over all configurations.
    """
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.float16),
        (3, 2048, 4096, 14336, 3408, torch.float16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup + correctness checks
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking. time.perf_counter is monotonic and high-resolution, whereas
        # time.time can jump if the wall clock is adjusted during the run.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        torch.cuda.synchronize()  # keep all queued GPU work inside the timed region
        elapsed += time.perf_counter() - start
    return elapsed

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0+cu124)
    Python  3.11.11 (you have 3.11.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
import triton
import triton.language as tl
import math

# Dequantize a double-quantized NF4 tensor, one output element per lane.
#
# For flat element index i (each program covers BLOCK_SIZE consecutive indices):
#   code   = 4-bit code i from q_ptr (two codes per byte; even i -> high nibble)
#   b      = i >> blocksize_log2
#   scale  = code2[absmax[b]] * absmax2[b >> blocksize2_log2] + offset
#   out[i] = nf4_table[code] * scale
#
# Arguments (as passed by my_dequantize_triton):
#   q_ptr          packed 4-bit codes (uint8)
#   absmax_ptr     one uint8 index per 2**blocksize_log2 elements, looked up in code2_ptr
#   code2_ptr      lookup table turning those indices into float scales
#   absmax2_ptr    one float scale per 2**blocksize2_log2 absmax entries
#   nf4_table_ptr  lookup table indexed by the 4-bit code (0..15)
#   out_ptr        flat output; tl.store converts the fp32 result to its element type
#   offset         scalar added to every dequantized absmax
#   n_elements     total number of output elements (tail lanes are masked off)
#
# NOTE(review): n_elements is tl.constexpr, so Triton compiles a separate kernel for
# every distinct tensor size; a runtime int argument would avoid those recompiles —
# confirm the specialization is worth it.
@triton.jit
def dequantize_nf4_kernel_with_ptx(
    q_ptr, absmax_ptr, code2_ptr, absmax2_ptr, nf4_table_ptr, out_ptr,
    offset: float,
    n_elements: tl.constexpr,
    blocksize_log2: tl.constexpr,
    blocksize2_log2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = tl.arange(0, BLOCK_SIZE)

    # Flat index of the element handled by each lane of this program.
    elem_idx = pid * BLOCK_SIZE + tid
    mask = elem_idx < n_elements

    # Two codes per byte: even element -> high nibble, odd element -> low nibble.
    byte_idx = elem_idx >> 1
    is_high_nibble = (elem_idx & 1) == 0

    # Masked-off lanes load undefined values, but the same mask is applied to the
    # final store, so those values are never written.
    q_byte = tl.load(q_ptr + byte_idx, mask=mask)

    nibble = tl.where(is_high_nibble, (q_byte >> 4) & 0xF, q_byte & 0xF)

    # Block indices via right shifts: only valid for power-of-two blocksizes.
    block_idx = elem_idx >> blocksize_log2
    block2_idx = block_idx >> blocksize2_log2

    # Widen the uint8 absmax code to an int32 gather index into code2.
    absmax_idx = tl.load(absmax_ptr + block_idx, mask=mask).to(tl.int32)

    scale1 = tl.load(code2_ptr + absmax_idx, mask=mask)
    scale2 = tl.load(absmax2_ptr + block2_idx, mask=mask)

    # final_scale = scale1 * scale2 + offset as one fused multiply-add with a single
    # rounding (fma.rn.f32: round to nearest even), emitted directly as inline PTX.
    # is_pure=True marks the asm as side-effect free; pack=1 = one element per asm instance.
    # NOTE(review): tl.fma(scale1, scale2, offset) may be a portable equivalent — confirm.
    final_scale = tl.inline_asm_elementwise(
        """fma.rn.f32 $0, $1, $2, $3;""",
        "=f,f,f,f",
        [scale1, scale2, offset],
        dtype=tl.float32,
        is_pure=True,
        pack=1
    )

    nf4_val = tl.load(nf4_table_ptr + nibble, mask=mask)

    result = nf4_val * final_scale

    # Streaming store ('.cs' cache modifier): each output element is written once and
    # not read again by this kernel, so hint the cache to evict these lines first.
    tl.store(out_ptr + elem_idx, result, mask=mask, cache_modifier='.cs')

def my_dequantize_triton(weight):
    """Dequantize a double-quantized NF4 Linear4bit module with the inline-PTX Triton kernel.

    Args:
        weight: bitsandbytes Linear4bit module whose `weight.quant_state` has nested
            statistics (`state2`/`offset`), as built with compress_statistics=True.

    Returns:
        Tensor of shape (out_features, in_features) with dtype `quant_state.dtype`.

    Raises:
        ValueError: if the quant state is not double-quantized, or if a blocksize is
            not a positive power of two. The kernel locates blocks with right shifts,
            so any other blocksize would silently produce wrong values.
    """
    q_data = weight.weight.data.view(-1)
    qs = weight.weight.quant_state
    if qs.state2 is None or qs.offset is None:
        raise ValueError("my_dequantize_triton requires a double-quantized (nested) quant_state")

    n_elements = weight.out_features * weight.in_features
    blocksize = int(qs.blocksize)
    blocksize2 = int(qs.state2.blocksize)
    for size in (blocksize, blocksize2):
        if size <= 0 or size & (size - 1):
            raise ValueError(f"blocksize must be a positive power of two, got {size}")

    # The kernel takes offset as a host float; .item() copies it from the tensor
    # (a host sync if the tensor lives on the GPU).
    offset = qs.offset.item()
    output = torch.empty(n_elements, device=q_data.device, dtype=qs.dtype)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    dequantize_nf4_kernel_with_ptx[grid](
        q_data, qs.absmax, qs.state2.code, qs.state2.absmax, qs.code, output,
        offset,
        n_elements,
        # Exact log2 of a validated power of two.
        blocksize.bit_length() - 1, blocksize2.bit_length() - 1,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.view(weight.out_features, weight.in_features)

In [4]:
# Baseline: total benchmark time (seconds) for unsloth's fast_dequantize.
test_dequantize(unsloth_dequantize)

5.146984577178955

In [5]:
# Candidate: same benchmark with the inline-PTX Triton kernel (lower is faster).
test_dequantize(my_dequantize_triton)

3.869208812713623

In [6]:
### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
# Baseline time / Triton time; a value > 1 means the Triton version is faster.
# Both benchmarks are re-run here, so the ratio need not match the two cells above.
test_dequantize(unsloth_dequantize) / test_dequantize(my_dequantize_triton)

1.133116476394387

In [4]:
### torch.compile versions below
# Reinstall torch/torchvision as a workaround for a torch.compile problem.
# Usage: run this cell once, restart the kernel, then re-run the setup cells above
# (imports, helpers, test harness). Do NOT run this cell again after the restart.
# NOTE(review): the recorded output shows pip reinstalling the same torch 2.6.0 wheel
# (and triton 3.2.0), so `--pre --upgrade` did not change versions here — confirm this
# step is still needed.
!pip uninstall torch -y
!pip install torch torchvision --pre --upgrade

Found existing installation: torch 2.6.0
Uninstalling torch-2.6.0:
  Successfully uninstalled torch-2.6.0
Collecting torch
  Using cached torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting triton==3.2.0 (from torch)
  Using cached triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Using cached torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl (766.7 MB)
Using cached triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.2 MB)
Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl (7.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.2/7.2 MB 65.4 MB/s eta 0:00:00
Installing collected packages: triton, torch, torchvision
  Attempting uninstall: triton
    Found existi

In [None]:
import triton
import triton.language as tl
import math

# Plain-Triton variant of the NF4 dequantize kernel for use under torch.compile:
# no inline PTX, no explicit fma and no store cache modifier.
#
# For flat element index i (each program covers BLOCK_SIZE consecutive indices):
#   code   = 4-bit code i from q_ptr (two codes per byte; even i -> high nibble)
#   b      = i >> blocksize_log2
#   scale  = code2[absmax[b]] * absmax2[b >> blocksize2_log2] + offset
#   out[i] = nf4_table[code] * scale
#
# Arguments (as passed by compiled_kernel):
#   q_ptr          packed 4-bit codes
#   absmax_ptr     one index per 2**blocksize_log2 elements, looked up in code2_ptr
#   code2_ptr      lookup table turning those indices into float scales
#   absmax2_ptr    one scale per 2**blocksize2_log2 absmax entries
#   nf4_table_ptr  lookup table indexed by the 4-bit code (0..15)
#   out_ptr        flat output; tl.store converts the result to its element type
#   offset         scalar added to every dequantized absmax
#   n_elements     total number of output elements (tail lanes are masked off)
#
# NOTE(review): n_elements is tl.constexpr, so Triton compiles a separate kernel for
# every distinct tensor size — confirm this is intended.
@triton.jit
def dequantize_nf4_kernel_torch_compile(
    q_ptr, absmax_ptr, code2_ptr, absmax2_ptr, nf4_table_ptr, out_ptr,
    offset: float,
    n_elements: tl.constexpr,
    blocksize_log2: tl.constexpr,
    blocksize2_log2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = tl.arange(0, BLOCK_SIZE)

    # Flat index of the element handled by each lane of this program.
    elem_idx = pid * BLOCK_SIZE + tid
    mask = elem_idx < n_elements

    # Two codes per byte: even element -> high nibble, odd element -> low nibble.
    byte_idx = elem_idx >> 1
    is_high_nibble = (elem_idx & 1) == 0

    # Masked-off lanes load undefined values; the store uses the same mask.
    q_byte = tl.load(q_ptr + byte_idx, mask=mask)

    nibble = tl.where(is_high_nibble, (q_byte >> 4) & 0xF, q_byte & 0xF)

    # Block indices via right shifts: only valid for power-of-two blocksizes.
    block_idx = elem_idx >> blocksize_log2
    block2_idx = block_idx >> blocksize2_log2

    # Widen the absmax code to an int32 gather index into code2.
    absmax_idx = tl.load(absmax_ptr + block_idx, mask=mask).to(tl.int32)

    scale1 = tl.load(code2_ptr + absmax_idx, mask=mask)
    scale2 = tl.load(absmax2_ptr + block2_idx, mask=mask)

    # do not use ptx for torch compile
    # do not use fma for torch compile
    # (Plain multiply + add; the Triton compiler may still contract it into an FMA.)
    final_scale = scale1 * scale2 + offset

    nf4_val = tl.load(nf4_table_ptr + nibble, mask=mask)

    result = nf4_val * final_scale

    # do not use cache eviction for torch compile (default store, no cache modifier)
    tl.store(out_ptr + elem_idx, result, mask=mask)

In [8]:
# NOTE(review): torch.compile guards on Python scalar arguments. `offset` (float) and
# `n_elements` (int) differ between weights, so this may recompile per weight or fall
# back to eager once the recompile limit is reached — check with TORCH_LOGS="recompiles".
@torch.compile(fullgraph=True)
def compiled_kernel(q_data, absmax, code2, absmax2, nf4_table,
                  offset, n_elements, blocksize, blocksize2, dtype):
    """Allocate a flat `dtype` output and launch `dequantize_nf4_kernel_torch_compile`.

    Returns a 1-D tensor of `n_elements` values on `q_data.device`. `blocksize` and
    `blocksize2` are assumed to be powers of two: they are turned into shift amounts
    with int(math.log2(...)), which truncates for any other value.
    """
    output = torch.empty(n_elements, device=q_data.device, dtype=dtype)

    BLOCK_SIZE = 1024
    blocksize_log2 = int(math.log2(blocksize))
    blocksize2_log2 = int(math.log2(blocksize2))

    # One program per BLOCK_SIZE output elements.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    dequantize_nf4_kernel_torch_compile[grid](
        q_data, absmax, code2, absmax2, nf4_table, output,
        offset,
        n_elements,
        blocksize_log2, blocksize2_log2,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output

def my_dequantize_triton_torch_compile(weight):
    """Dequantize a double-quantized NF4 Linear4bit module via the torch.compile'd launcher.

    Args:
        weight: bitsandbytes Linear4bit module whose `weight.quant_state` has nested
            statistics (`state2`/`offset`), as built with compress_statistics=True.

    Returns:
        Tensor of shape (out_features, in_features) with dtype `quant_state.dtype`.

    Raises:
        ValueError: if the quant state is not double-quantized, or if a blocksize is
            not a positive power of two. The kernel locates blocks with right shifts
            and `compiled_kernel` derives them with int(math.log2(...)), so any other
            blocksize would silently produce wrong values.
    """
    q_data = weight.weight.data.view(-1)
    qs = weight.weight.quant_state
    if qs.state2 is None or qs.offset is None:
        raise ValueError("my_dequantize_triton_torch_compile requires a double-quantized (nested) quant_state")

    n_elements = weight.out_features * weight.in_features
    blocksize = qs.blocksize
    blocksize2 = qs.state2.blocksize
    for size in (blocksize, blocksize2):
        if size <= 0 or size & (size - 1):
            raise ValueError(f"blocksize must be a positive power of two, got {size}")

    # The kernel takes offset as a host float; .item() copies it from the tensor
    # (a host sync if the tensor lives on the GPU).
    offset = qs.offset.item()

    output = compiled_kernel(q_data, qs.absmax, qs.state2.code, qs.state2.absmax, qs.code,
                             offset, n_elements, blocksize, blocksize2, qs.dtype)

    return output.view(weight.out_features, weight.in_features)

In [9]:
# Baseline again, measured in this session: total benchmark time (seconds) for unsloth.
test_dequantize(unsloth_dequantize)

5.694186687469482

In [10]:
# Candidate: the torch.compile-wrapped Triton kernel (lower is faster).
test_dequantize(my_dequantize_triton_torch_compile)

5.683187484741211

In [11]:
### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
# Baseline time / torch.compile Triton time; a value > 1 means the Triton version is faster.
# Both benchmarks are re-run here, so the ratio need not match the two cells above.
test_dequantize(unsloth_dequantize) / test_dequantize(my_dequantize_triton_torch_compile)

1.0133595462093137