<a href="https://colab.research.google.com/github/byi8220/unsloth-puzzles/blob/main/Problem1/TritonNF4Kernel-on-T4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unsloth Problem 1 - Convert nf4 to Triton

Run on a Tesla T4 Colab instance.

(Note: the Tesla T4 does not support `bfloat16`. Since we must use a T4, we can only use regular `float16`.)
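The notebook applies one dtype coercion rule in several places: the `HAS_BFLOAT16` check, the fallback in `test_dequantize`, and the one in the kernel wrapper. Here is a minimal GPU-free sketch of that rule. It represents dtypes as strings and takes the `(major, minor)` compute capability as input. `effective_dtype` is a hypothetical helper, not part of Unsloth or PyTorch.

```python
def effective_dtype(requested: str, capability: tuple) -> str:
    """Return the dtype a kernel should actually produce on a GPU with the given
    (major, minor) compute capability. Hypothetical helper mirroring HAS_BFLOAT16."""
    major, _minor = capability
    has_bfloat16 = major >= 8  # same threshold as HAS_BFLOAT16 in this notebook
    if requested == "bfloat16" and not has_bfloat16:
        return "float16"  # fall back, as test_dequantize does on a T4
    return requested


# A Tesla T4 reports compute capability (7, 5), so bfloat16 requests become float16.
print(effective_dtype("bfloat16", (7, 5)))  # float16
print(effective_dtype("bfloat16", (8, 9)))  # bfloat16
```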

## Problem Statement
---
---
---
<a name="NF4"></a>
## A) Convert `nf4` to Triton. [Difficulty: Hard] [Max points: 14]

1. Goal: Convert an `nf4` quantized tensor into `fp16` or `bf16` using a *single* Triton kernel. The double dequant of the `absmax` and weight forming must be done in 1 Triton kernel. Must work on Tesla T4.
2. Must be faster than Unsloth's `fast_dequantize` by 1.15x or more, and not use large intermediate memory buffers.
3. Must not use `torch.compile`, but can use `trace.enabled` to help on writing Triton kernels.
4. Good material: [Unsloth `fast_dequantize` function](https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/utils.py#L128), also [bitsandbytes `dequantize_blockwise`](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/86b6c37a8ad448230cedb60753f63150b603a112/bitsandbytes/functional.py#L958)
5. Use `test_dequantize_function` to test your implementation.
6. No CUDA allowed. Custom CUDA inside of the Triton is allowed.
7. Watch Tim's videos on Youtube: [8-bit Optimizers](https://www.youtube.com/watch?v=2ETNONas068)
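The following minimal pure-Python sketch shows what the fused kernel has to compute in the "double dequant". It assumes the bitsandbytes layout that the kernel below relies on:

- two 4-bit indices per byte, high nibble first;
- one 8-bit `absmax` code per `blocksize` weights;
- one fp32 `absmax2` scale per `blocksize2` absmax codes.

The lookup tables in the usage example are toy values, not the real NF4 or 8-bit dynamic codes, and the function names are hypothetical.

```python
def unpack_nibbles(packed):
    """Split each uint8 into its two 4-bit indices; the high nibble comes first."""
    indices = []
    for byte in packed:
        indices.append(byte >> 4)
        indices.append(byte & 0x0F)
    return indices


def dequantize_absmax(absmax_q, code2, absmax2, offset, blocksize2):
    """First dequant: 8-bit absmax codes -> float absmax, then add the global offset."""
    return [code2[q] * absmax2[i // blocksize2] + offset for i, q in enumerate(absmax_q)]


def dequantize_nf4(packed, absmax_q, code, code2, absmax2, offset, blocksize, blocksize2):
    """Second dequant: each weight is code[index] scaled by the absmax of its block."""
    absmax = dequantize_absmax(absmax_q, code2, absmax2, offset, blocksize2)
    indices = unpack_nibbles(packed)
    return [code[j] * absmax[k // blocksize] for k, j in enumerate(indices)]


# Toy example: blocksize = 2 weights per absmax, blocksize2 = 2 absmax per absmax2.
toy_code = [float(i) for i in range(16)]    # stand-in for the 16-entry NF4 table
toy_code2 = [float(i) for i in range(256)]  # stand-in for the 256-entry 8-bit table
weights = dequantize_nf4(
    packed=[0x12, 0x34], absmax_q=[1, 2], code=toy_code, code2=toy_code2,
    absmax2=[0.5], offset=0.5, blocksize=2, blocksize2=2,
)
print(weights)  # [1.0, 2.0, 4.5, 6.0]
```

The Triton kernel fuses both steps. It computes `code2[absmax] * absmax2 + offset` with `tl.fma` and multiplies the looked-up `code` values by the result, so the intermediate fp32 `absmax` is never written to global memory.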

In [None]:
%%capture
# Code to install Unsloth, Triton, Torch etc (the %%capture cell magic must be the first line of the cell)
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl
!pip install triton==3.1.0 # (https://github.com/unslothai/unsloth/issues/1604)
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

!pip install --no-deps unsloth==2025.3.4 # Stick to stable version
!pip install --no-deps unsloth_zoo==2025.3.4 # Stick to stable version

In [None]:
# Unsloth yells at me to import it before transformers.
import unsloth

# Helpful functions used throughout the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    assert(x.dtype == dtype)
    # Tolerances loosened due to https://x.com/danielhanchen/status/1893177157733490920
    try: torch.testing.assert_close(x, y, check_stride = True, atol=0.001, rtol=0.001)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0+cu124)
    Python  3.11.11 (you have 3.11.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    return Linear4bit(
        hd, m, bias = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    assert(weight.weight.quant_state.dtype == dtype)
    assert(weight.weight.quant_state.absmax.dtype == torch.uint8)
    assert(weight.weight.quant_state.code.dtype == torch.float32)
    assert(weight.weight.quant_state.offset.dtype == torch.float32)
    assert(weight.weight.quant_state.blocksize == 64)
    assert(weight.weight.quant_state.state2.absmax.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.code.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.blocksize == 256)

class MLP(nn.Module):
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj  .weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.  up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.  up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx, compile=False):
    elapsed = 0
    # Note: The latter two won't actually run in bf16 on a T4.
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        if not HAS_BFLOAT16 and dt == torch.bfloat16:
            dt = torch.float16 # Coerce to float16 for T4 instances
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        if compile:
            mlp = torch.compile(mlp)
            dequantize_fx = torch.compile(dequantize_fx)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()
        # Warmup
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.time() - start
    return elapsed

For example, we can test Unsloth's reference `fast_dequantize` implementation via:

In [None]:
from unsloth.kernels.utils import fast_dequantize
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)
test_dequantize(unsloth_dequantize)

4.58371901512146

The elapsed time for Unsloth's implementation (1000 iterations for each of the three test shapes) is about 4.58 seconds in this run.
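CUDA kernel launches are asynchronous, so the harness synchronizes before it reads the clock. Below is a minimal sketch of the same pattern as a reusable helper. It assumes the caller passes `torch.cuda.synchronize` as `sync`; the default no-op keeps it runnable without a GPU. `benchmark` is a hypothetical name. The sketch uses `time.perf_counter` because it is a monotonic, high-resolution clock.

```python
import time


def benchmark(fn, iters=1000, warmup=2, sync=lambda: None):
    """Return the wall-clock seconds taken by `iters` calls of fn().

    `sync` must block until all queued GPU work has finished (pass
    torch.cuda.synchronize on a GPU); without it we would only time the
    asynchronous kernel launches, not the kernels themselves.
    """
    for _ in range(warmup):  # e.g. triggers Triton JIT compilation outside the timed region
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for the last kernels before stopping the clock
    return time.perf_counter() - start
```

On a GPU, you could time a single weight with `benchmark(lambda: unsloth_dequantize(layer), sync=torch.cuda.synchronize)`, where `layer` is some `Linear4bit` module.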

PEFT also has a dequantization function, which should be mostly identical to Unsloth's version, albeit slightly slower.

In [None]:
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
test_dequantize(peft_dequantize)

4.694403648376465

Write your Triton kernel below, and test it:

In [None]:
from triton import jit
import triton
import triton.language as tl

@triton.jit
def _your_dequantize_nf4_kernel(w_ptr, absmax_ptr, absmax2_ptr, out_ptr,
                                code_ptr, code2_ptr,  # Can we make these constexpr somehow?
                                num_blocks: tl.constexpr,
                                num_elements: tl.constexpr,
                                n_absmax: tl.constexpr,
                                n_absmax2: tl.constexpr,
                                n_out: tl.constexpr,
                                offset: tl.constexpr,
                                kernel_dtype: tl.constexpr,
                                blocksize: tl.constexpr,
                                blocksize2: tl.constexpr,):
    # Contiguous Stride Solution
    # We know that absmax, absmax2, and w are contiguous
    # Therefore, for each program_id we can process slices of `absmax`, provided they all share the same absmax2.
    # If this is insufficient we can generalize this to slicing over absmax2.
    first_block = tl.program_id(0) * num_blocks # What is the first absmax block we are processing
    last_block = first_block + (num_blocks-1)
    # Assert all absmax1 blocks share an absmax2 block
    block2 = first_block // blocksize2
    last_block2 = last_block // blocksize2
    tl.device_assert(block2 == last_block2)
    absmax2 = tl.load(absmax2_ptr + block2, mask=block2 < n_absmax2)

    # Read the absmax blocks we want
    absmax_read_range = first_block + tl.arange(0, num_blocks)
    absmax_ix = tl.load(absmax_ptr + absmax_read_range, mask=absmax_read_range < n_absmax).cast(tl.uint16) # Must upcast due to https://github.com/triton-lang/triton/issues/6043
    absmax_codes = tl.load(code2_ptr + absmax_ix, mask = absmax_ix < 256)
    offsetted_absmax = tl.fma(absmax_codes, absmax2, offset)

    # Load the slice of `w_ptr` we are working with
    first_element = first_block * blocksize
    w_offset = first_element // 2
    w_range = w_offset + tl.arange(0, num_elements // 2)
    n_w = n_out // 2
    w = tl.load(w_ptr + w_range, mask=w_range < n_w)
    unpacked_w = tl.interleave(w >> 4, w & 0xF).cast(tl.uint16)

    #`gather` is not supported in triton 3.1.0 or 3.2.0: https://github.com/triton-lang/triton/issues/5826
    output = tl.load(code_ptr + unpacked_w, mask=unpacked_w < 16).reshape((num_blocks, blocksize))
    offsetted_absmax = offsetted_absmax.expand_dims(-1)
    write_out = output * offsetted_absmax
    write_out = write_out.reshape((num_elements,))
    o_offset = first_element
    o_range = o_offset + tl.arange(0, num_elements)
    tl.store(out_ptr + o_range, write_out, mask=o_range<n_out, cache_modifier=".cs") # We don't need the output in cache, it's never reused
    return

TORCH_TO_TRITON_DTYPE = {
    torch.float16  : tl.float16,
    torch.bfloat16 : tl.bfloat16,
    torch.float32  : tl.float32
}


def _your_dequantize_nf4(weight, quant_state):
    ### SETUP TRITON LAUNCH HERE
    kernel_dtype = quant_state.dtype
    if not HAS_BFLOAT16 and quant_state.dtype == torch.bfloat16:
        kernel_dtype = torch.float16 # Coerce to float16 for T4 instance
    out = torch.empty(quant_state.shape,
                      dtype=kernel_dtype,
                      device=weight.device,
                      requires_grad = False)
    is_transposed = weight.shape[0] == 1
    n_out = out.numel()
    n_absmax = quant_state.absmax.numel()
    n_absmax2 = quant_state.state2.absmax.numel()

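    ov = out.view(-1)
    # Each program processes 64 consecutive absmax blocks (64 * 64 = 4096 weights with blocksize 64).
    # Assumes n_absmax is a multiple of 64, so that num_blocks below is exactly 64 (a power of 2, as tl.arange needs).
    grid = (n_absmax // 64,)

    num_blocks = n_absmax // grid[0]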
    compiled_kernel = _your_dequantize_nf4_kernel[grid](weight, quant_state.absmax,
                                      quant_state.state2.absmax, ov,
                                      quant_state.code,
                                      quant_state.state2.code,
                                      num_blocks=num_blocks,
                                      num_elements=num_blocks * quant_state.blocksize,
                                      n_absmax=n_absmax,
                                      n_absmax2=n_absmax2,
                                      n_out=n_out,
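                                      # offset is a constexpr: each distinct value triggers a new kernel compilation,
                                      # and .item() forces a device-to-host sync on every call.
                                      offset=quant_state.offset.item(),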
                                      kernel_dtype=TORCH_TO_TRITON_DTYPE[kernel_dtype],
                                      blocksize=quant_state.blocksize,
                                      blocksize2=quant_state.state2.blocksize)
    torch.cuda.synchronize()

    out = ov.view(out.shape)
    if is_transposed:
        return out.t()  # Tensor.transpose() requires two dims; .t() swaps the axes of a 2D tensor
    else:
        return out

def your_dequantize_nf4(weight):
    return _your_dequantize_nf4(weight.weight.data, weight.weight.quant_state)
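The launch above gives each program `num_blocks` consecutive `absmax` blocks. The kernel's `tl.device_assert` then requires all of those blocks to fall under a single `absmax2` scale. The pure-Python sketch below mirrors that index arithmetic, which makes the constraint easy to check for a given shape. `program_ranges` is a hypothetical helper: it only reproduces the kernel's indexing and does not run the kernel.

```python
def program_ranges(n_absmax, num_blocks, blocksize=64, blocksize2=256):
    """For each program id, return (pid, absmax2 block, first output element,
    end output element), mirroring the kernel's indexing."""
    if n_absmax % num_blocks != 0:
        raise ValueError("the launch assumes n_absmax is a multiple of num_blocks")
    ranges = []
    for pid in range(n_absmax // num_blocks):
        first_block = pid * num_blocks
        last_block = first_block + num_blocks - 1
        block2 = first_block // blocksize2
        # Same condition as tl.device_assert(block2 == last_block2) in the kernel
        if block2 != last_block // blocksize2:
            raise AssertionError(f"program {pid} spans two absmax2 blocks")
        first_element = first_block * blocksize
        # Packed weights for this program start at byte first_element // 2
        ranges.append((pid, block2, first_element, first_element + num_blocks * blocksize))
    return ranges


# 512 absmax blocks, 64 per program -> 8 programs; programs 0-3 share absmax2 block 0.
for r in program_ranges(512, 64):
    print(r)
```

For the real launch, `num_blocks = 64` divides `blocksize2 = 256`. Every program's first block is therefore a multiple of 64, so no program crosses an `absmax2` boundary.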

In [None]:
set_seed(3407)
a = bnb_Linear4bit(2048, 8192, dtype = torch.float16).to("cuda")
a.weight.quant_state.dtype = torch.float16

expected = unsloth_dequantize(a)
actual = your_dequantize_nf4(a)

torch.testing.assert_close(expected, actual, atol=0.001, rtol=0.001)

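Note that our dequantization differs slightly from Unsloth's above; the comparison passes with `atol=0.001, rtol=0.001`. This could be a bug in our kernel, or an issue on the CUDA side, for example a difference in floating-point rounding between the two implementations.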
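Here is one plausible source of such a rounding difference, offered as a hypothetical illustration rather than a diagnosis of this kernel. The kernel applies the offset with `tl.fma`, which rounds once. A reference that computes the product and the sum as separate operations rounds twice. The sketch below works in float64 Python and uses `fractions.Fraction` to emulate the single rounding. Python 3.13+ also provides `math.fma` directly.

```python
from fractions import Fraction


def fused_multiply_add(a: float, b: float, c: float) -> float:
    """Emulate fma: compute a * b + c exactly, then round to float once."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = b = 1.0 + 2.0**-30
c = -1.0
separate = a * b + c                  # a * b is rounded before the add
fused = fused_multiply_add(a, b, c)   # only the final result is rounded
print(separate == 2.0**-29)           # True
print(fused == 2.0**-29 + 2.0**-60)   # True
```

In the kernel the difference would appear at fp32 precision. It would only show up in the `float16` output when it happens to change how a value rounds to half precision.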

In [None]:
print("Can we use BFLOAT16:", HAS_BFLOAT16)
# TEST IT BELOW:
RUNS = 5
bench = []
for _ in range(RUNS):
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    dequant_time = test_dequantize(your_dequantize_nf4)

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    reference_time = test_dequantize(unsloth_dequantize)

    print("Triton kernel time:", dequant_time)
    print("Reference unsloth kernel time:", reference_time)
    ### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
    # Somehow, it is!
    # The tolerances are really loose (1e-3 rtol and atol)
    ratio = reference_time / dequant_time
    bench.append(ratio)
    print(ratio)
print("Average runtime ratio:", sum(bench)/len(bench))

Can we use BFLOAT16: False
Triton kernel time: 4.000486612319946
Reference unsloth kernel time: 4.635912656784058
1.1588371880828812
Triton kernel time: 3.976804733276367
Reference unsloth kernel time: 4.66949462890625
1.1741825264473564
Triton kernel time: 3.968602180480957
Reference unsloth kernel time: 4.737403392791748
1.193720906593268
Triton kernel time: 3.9665496349334717
Reference unsloth kernel time: 4.81851053237915
1.214786395194061
Triton kernel time: 3.9662768840789795
Reference unsloth kernel time: 4.902519464492798
1.2360507366926365
Average runtime ratio: 1.1955155506020405


**NOTE:** The result above shows the kernel's performance on a T4 (where we only test `float16`); the same kernel is significantly slower on an L4 GPU.

In [None]:
#@title Test with compiled MLP

RUNS = 5
bench = []
for _ in range(RUNS):
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    dequant_time = test_dequantize(your_dequantize_nf4, compile=True)

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    reference_time = test_dequantize(unsloth_dequantize, compile=True)

    print("Triton kernel time:", dequant_time)
    print("Reference unsloth kernel time:", reference_time)
    ### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
    # The Triton kernel isn't torch.compile friendly (e.g. `.item()` on the offset likely causes a graph break). We lose time on compilation?
    ratio = reference_time / dequant_time
    bench.append(ratio)
    print(ratio)

print("Average runtime ratio:", sum(bench)/len(bench))

Triton kernel time: 5.698766708374023
Reference unsloth kernel time: 5.064429759979248
0.888688731991318
Triton kernel time: 5.911512136459351
Reference unsloth kernel time: 5.176299333572388
0.875630331814338
Triton kernel time: 5.882539987564087
Reference unsloth kernel time: 5.290977716445923
0.8994376115812643
Triton kernel time: 5.869607210159302
Reference unsloth kernel time: 5.25044059753418
0.8945131095734228
Triton kernel time: 5.866735935211182
Reference unsloth kernel time: 5.166455030441284
0.8806353460419231
Average runtime ratio: 0.8877810262004532


## Misc. Functions

### Kernel Parameter Sweep

With GPU code, kernel launch parameters can dramatically affect performance. Selecting good parameters can be tricky and is input-shape dependent.

In total we have three knobs to tune: `(size, num_warps, num_stages)`. Here `size` is the number of `absmax` blocks (64 weights each) that each program processes.
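The sweep is an exhaustive grid search over these three knobs. Below is a small GPU-free sketch of the same idea. `measure` is a hypothetical callback that returns the runtime in seconds for one `(size, num_warps, num_stages)` configuration, and the lowest runtime wins. The toy cost model only stands in for a real benchmark.

```python
from itertools import product


def grid_search(measure, sizes, warps, stages):
    """Call measure(size, num_warps, num_stages) for every combination and
    return (best_config, {config: seconds})."""
    results = {cfg: measure(*cfg) for cfg in product(sizes, warps, stages)}
    best = min(results, key=results.get)
    return best, results


# Toy cost model standing in for a real benchmark (not measured data).
def toy_measure(size, num_warps, num_stages):
    return abs(size - 64) + 0.01 * num_warps + 0.001 * num_stages


best, results = grid_search(toy_measure, [1, 2, 4, 8, 16, 32, 64, 128], [1, 2, 4], [1, 2, 3, 4, 8])
print(best)          # (64, 1, 1)
print(len(results))  # 120
```

Triton also ships `triton.autotune`, which benchmarks a list of `triton.Config` objects (meta-parameters plus `num_warps` and `num_stages`) and caches the fastest one for each value of its `key` arguments.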

In [None]:
#  Parameter Sweep
RUNS = 1

SIZE_DENOM = [1, 2, 4, 8, 16, 32, 64, 128]

WARPS = [1,2,4]

# Software pipelining depths to try, passed to the launch as num_stages.
STAGES = [1, 2, 3, 4, 8]

def _your_dequantize_nf4_param(weight, quant_state, size=1, warps=1, stages=1):
    ### SETUP TRITON LAUNCH HERE
    kernel_dtype = quant_state.dtype
    if not HAS_BFLOAT16 and quant_state.dtype == torch.bfloat16:
        kernel_dtype = torch.float16 # Coerce to float16 for T4 instance
    out = torch.empty(quant_state.shape,
                      dtype=kernel_dtype,
                      device=weight.device,
                      requires_grad = False)
    is_transposed = weight.shape[0] == 1
    n_out = out.numel()
    n_absmax = quant_state.absmax.numel()
    n_absmax2 = quant_state.state2.absmax.numel()

    ov = out.view(-1)
    grid = (n_absmax // size,)

    num_blocks = n_absmax // grid[0]
    compiled_kernel = _your_dequantize_nf4_kernel[grid](weight, quant_state.absmax,
                                      quant_state.state2.absmax, ov,
                                      quant_state.code,
                                      quant_state.state2.code,
                                      num_blocks=num_blocks,
                                      num_elements=num_blocks * quant_state.blocksize,
                                      n_absmax=n_absmax,
                                      n_absmax2=n_absmax2,
                                      n_out=n_out,
                                      offset=quant_state.offset.item(),
                                      kernel_dtype=TORCH_TO_TRITON_DTYPE[kernel_dtype],
                                      blocksize=quant_state.blocksize,
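                                      blocksize2=quant_state.state2.blocksize,
                                      num_warps=warps,    # Triton launch option being swept
                                      num_stages=stages)  # Triton launch option being swept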
    torch.cuda.synchronize()

    out = ov.view(out.shape)
    if is_transposed:
        return out.t()  # Tensor.transpose() requires two dims; .t() swaps the axes of a 2D tensor
    else:
        return out

def your_dequantize_nf4_param(weight, size=1, warps=1, stages=1):
    return _your_dequantize_nf4_param(weight.weight.data, weight.weight.quant_state, size=size, warps=warps, stages=stages)

from functools import partial
for sz in SIZE_DENOM:
    for warp in WARPS:
        for stage in STAGES:
            bench = []
            parameterized_dequant = partial(your_dequantize_nf4_param, size=sz, warps=warp, stages=stage)
            for _ in range(RUNS):
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
                dequant_time = test_dequantize(parameterized_dequant)

                torch.cuda.synchronize()
                torch.cuda.empty_cache()
                reference_time = test_dequantize(unsloth_dequantize)

                ### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
                # (It's not. It's actually quite subpar. I tried :P)
                ratio = reference_time / dequant_time
                bench.append(ratio)
                print(f"{(sz, warp, stage)} - Average runtime ratio:", sum(bench)/len(bench))

(1, 1, 1) - Average runtime ratio: 0.4117947205358429
(1, 1, 2) - Average runtime ratio: 0.41324931640489404
(1, 1, 3) - Average runtime ratio: 0.4132550341887066
(1, 1, 4) - Average runtime ratio: 0.41390663854701126
(1, 1, 8) - Average runtime ratio: 0.41316792619909765
(1, 2, 1) - Average runtime ratio: 0.4148985525419137
(1, 2, 2) - Average runtime ratio: 0.4124578548290544
(1, 2, 3) - Average runtime ratio: 0.41335346684326457
(1, 2, 4) - Average runtime ratio: 0.4168257273399646
(1, 2, 8) - Average runtime ratio: 0.41670801757226655
(1, 4, 1) - Average runtime ratio: 0.41539329132416014
(1, 4, 2) - Average runtime ratio: 0.41341721693988975
(1, 4, 3) - Average runtime ratio: 0.41152206717844486
(1, 4, 4) - Average runtime ratio: 0.4125461569600087
(1, 4, 8) - Average runtime ratio: 0.4147916797269728
(2, 1, 1) - Average runtime ratio: 0.6808346438779074
(2, 1, 2) - Average runtime ratio: 0.6855004614324384
(2, 1, 3) - Average runtime ratio: 0.6844380209006784
(2, 1, 4) - Average 

As we can see, the kernel's performance varies drastically depending on whether we pick good launch parameters. In the partial output above, nearly all of the variation comes from `size`.
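A back-of-the-envelope look at how `size` changes the work per program for the largest test weight (4096 x 14336) may help explain why small sizes do so poorly. With `size = 1`, each of roughly 900k programs handles only 64 weights (32 bytes of packed data), so fixed per-program overhead likely dominates. `launch_stats` is a hypothetical arithmetic helper, not a profile.

```python
def launch_stats(rows, cols, size, blocksize=64, blocksize2=256):
    """Work decomposition for grid = (n_absmax // size,), as in the launch code."""
    n_elements = rows * cols
    n_absmax = n_elements // blocksize   # one 8-bit absmax code per 64 weights
    n_absmax2 = n_absmax // blocksize2   # one fp32 scale per 256 absmax codes
    n_programs = n_absmax // size        # number of Triton programs launched
    elems_per_program = size * blocksize
    return {
        "programs": n_programs,
        "elements_per_program": elems_per_program,
        "packed_bytes_per_program": elems_per_program // 2,  # two 4-bit weights per byte
        "absmax2_blocks": n_absmax2,
    }


for size in (1, 64):
    print(size, launch_stats(4096, 14336, size))
```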