# NF4 Triton Dequantization Benchmarks

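Benchmarks a Triton kernel for NF4 dequantization against the existing Unsloth (`fast_dequantize`) and PEFT (`dequantize_module_weight`) implementations, checking that the outputs match and measuring the speed of both standalone dequantization and an MLP forward pass.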

In [None]:
# Install dependencies
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

# Clone and install the repo if running in Colab
import os
if not os.path.exists('nf4_triton_dequantization'):
    !git clone https://github.com/felipemcoelho/nf4-triton-dequantization.git
    %cd nf4-triton-dequantization
    !pip install -e .

In [None]:
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
import matplotlib.pyplot as plt

# Check CUDA availability. HAS_BFLOAT16 is defined on both branches so later
# cells can read it even on a machine without a GPU.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    major_version, minor_version = torch.cuda.get_device_capability()
    print(f"CUDA capability: {major_version}.{minor_version}")
    print(f"Device: {torch.cuda.get_device_name(0)}")
    # Native bfloat16 support starts at compute capability 8.x (Ampere).
    HAS_BFLOAT16 = (major_version >= 8)
    print(f"bfloat16 support: {HAS_BFLOAT16}")
else:
    HAS_BFLOAT16 = False
    print("CUDA not available. This benchmark requires a CUDA-capable GPU.")

## 1. NF4 Quantization Test Model

In [None]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN

def bnb_Linear4bit(hd, m, dtype=torch.float16):
    """Build a bias-free bitsandbytes NF4 ``Linear4bit`` layer.

    Args:
        hd: Number of input features.
        m: Number of output features.
        dtype: ``compute_dtype`` passed to ``Linear4bit``.

    Returns:
        A ``Linear4bit`` module configured for NF4 with compressed statistics.
    """
    return Linear4bit(
        hd, m,
        bias=False,  # `bias` is a bool flag; None only worked because it is falsy
        compute_dtype=dtype,
        # Nested quantization of the block scales (its second-level state is
        # what `quant_state.state2` exposes).
        compress_statistics=True,
        quant_type="nf4",
    )

class MLP(nn.Module):
    """Gated (SwiGLU-style) MLP whose projections are NF4-quantized on CUDA.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``. After the three
    layers are created and moved to CUDA, each weight's ``quant_state.dtype``
    is overwritten with ``dtype``.

    Args:
        hd: Hidden size (input/output width).
        m: Intermediate size.
        dtype: Compute dtype for every projection.
    """

    def __init__(self, hd=4096, m=14336, dtype=torch.float16):
        super().__init__()

        def quantized(in_features, out_features):
            # One NF4 projection, placed on the GPU.
            return bnb_Linear4bit(in_features, out_features, dtype=dtype).to("cuda")

        self.gate_proj = quantized(hd, m)
        self.up_proj = quantized(hd, m)
        self.down_proj = quantized(m, hd)
        for projection in (self.gate_proj, self.up_proj, self.down_proj):
            projection.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

# Build a small NF4 MLP and a matching input. Seed first so both the model
# and the random input are reproducible.
set_seed(42)
hd, m = 1024, 4096
dtype = torch.float16
mlp = MLP(hd=hd, m=m, dtype=dtype)

# Input activations of shape (batch, seq, hidden).
batch_size, seq_len = 2, 128
x = torch.randn((batch_size, seq_len, hd), device="cuda", dtype=dtype)

# Inspect the packed 4-bit weight and its quantization metadata.
gate_weight = mlp.gate_proj.weight
gate_qstate = gate_weight.quant_state
print(f"Weight dtype: {gate_weight.dtype}")
print(f"Weight shape: {gate_weight.shape}")
print(f"Compute dtype: {gate_qstate.dtype}")
print(f"Blocksize: {gate_qstate.blocksize}")
print(f"Secondary blocksize: {gate_qstate.state2.blocksize}")

## 2. Comparing Dequantization Methods

In [None]:
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

def unsloth_dequantize(weight):
    """Dequantize a 4-bit layer's weight with Unsloth's ``fast_dequantize``.

    Despite the parameter name, ``weight`` is the layer module
    (e.g. ``mlp.gate_proj``). Its ``.weight`` parameter and that parameter's
    ``quant_state`` are forwarded to ``fast_dequantize``.
    """
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

# CUDA launches are asynchronous and the first call can pay one-off costs
# (lazy initialization, kernel compilation). So each measurement warms up,
# synchronizes before starting the clock, and averages over several calls.
N_WARMUP, N_ITERS = 3, 20

# Test Unsloth dequantization
for _ in range(N_WARMUP):
    unsloth_dequantize(mlp.gate_proj)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N_ITERS):
    dequant_weight = unsloth_dequantize(mlp.gate_proj)
torch.cuda.synchronize()
unsloth_time = (time.perf_counter() - start) / N_ITERS

print(f"Unsloth dequantization time: {unsloth_time*1000:.2f} ms")
print(f"Dequantized dtype: {dequant_weight.dtype}")
print(f"Dequantized shape: {dequant_weight.shape}")

# Test PEFT dequantization
for _ in range(N_WARMUP):
    peft_dequantize(mlp.gate_proj)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N_ITERS):
    dequant_weight_peft = peft_dequantize(mlp.gate_proj)
torch.cuda.synchronize()
peft_time = (time.perf_counter() - start) / N_ITERS

print(f"PEFT dequantization time: {peft_time*1000:.2f} ms")
print(f"Results match: {torch.allclose(dequant_weight, dequant_weight_peft)}")

## 3. Testing Triton Implementation

In [None]:
from nf4_triton_dequantization import triton_dequantize_nf4

# Test Triton implementation.
# Warm up first: the first call typically includes Triton JIT compilation,
# which would otherwise dominate the measurement. CUDA launches are
# asynchronous, so synchronize before starting and before stopping the clock.
N_WARMUP, N_ITERS = 3, 20

for _ in range(N_WARMUP):
    triton_dequantize_nf4(mlp.gate_proj)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N_ITERS):
    dequant_weight_triton = triton_dequantize_nf4(mlp.gate_proj)
torch.cuda.synchronize()
triton_time = (time.perf_counter() - start) / N_ITERS

print(f"Triton dequantization time: {triton_time*1000:.2f} ms")
print(f"Results match Unsloth: {torch.allclose(dequant_weight, dequant_weight_triton)}")
print(f"Results match PEFT: {torch.allclose(dequant_weight_peft, dequant_weight_triton)}")

# Calculate speedup.
# NOTE(review): only meaningful if unsloth_time / peft_time were measured the
# same way (warmed up, synchronized, averaged per call).
speedup_vs_unsloth = unsloth_time / triton_time
speedup_vs_peft = peft_time / triton_time

print(f"\nSpeedup vs Unsloth: {speedup_vs_unsloth:.2f}x")
print(f"Speedup vs PEFT: {speedup_vs_peft:.2f}x")

## 4. Running Benchmarks

In [None]:
from benchmark import run_benchmarks, plot_benchmarks

# Iteration counts are kept modest so the notebook finishes quickly;
# raise them for more stable timings.
results = run_benchmarks(warmup=2, iterations=100)

In [None]:
# Plot the results returned by run_benchmarks above.
# NOTE(review): assumes plot_benchmarks returns a matplotlib Figure — confirm in benchmark.py.
fig = plot_benchmarks(results)

## 5. Forward Pass Impact

In [None]:
def mlp_forward(X, mlp, fx):
    """Run the gated MLP manually, materializing each weight with ``fx``.

    Args:
        X: Input activations whose last dim matches the projections' input width.
        mlp: Object exposing ``up_proj``, ``gate_proj``, ``down_proj`` and ``act_fn``.
        fx: Callable mapping a projection to its dense (out, in) weight matrix.

    Returns:
        ``act_fn(X @ Wg^T) * (X @ Wu^T)`` projected through ``Wd^T``.
    """
    def project(inputs, layer):
        # Dense linear map with the dequantized weight (no bias).
        return inputs @ fx(layer).t()

    # Keep the original call order (up, then gate, then down) in case fx is stateful.
    up = project(X, mlp.up_proj)
    gate = project(X, mlp.gate_proj)
    return project(mlp.act_fn(gate) * up, mlp.down_proj)

# Create a larger input for more realistic testing
batch_size, seq_len = 4, 512
x_large = torch.randn((batch_size, seq_len, hd), device="cuda", dtype=dtype)

# Warm up each path before timing: down_proj has a different shape than the
# layers dequantized earlier, so its first call may trigger fresh compilation.
# Synchronize around the timed region because CUDA launches are asynchronous.
N_WARMUP, N_ITERS = 3, 20

# Measure time for Unsloth-based forward pass
for _ in range(N_WARMUP):
    mlp_forward(x_large, mlp, unsloth_dequantize)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N_ITERS):
    output_unsloth = mlp_forward(x_large, mlp, unsloth_dequantize)
torch.cuda.synchronize()
unsloth_forward_time = (time.perf_counter() - start) / N_ITERS

# Measure time for Triton-based forward pass
for _ in range(N_WARMUP):
    mlp_forward(x_large, mlp, triton_dequantize_nf4)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N_ITERS):
    output_triton = mlp_forward(x_large, mlp, triton_dequantize_nf4)
torch.cuda.synchronize()
triton_forward_time = (time.perf_counter() - start) / N_ITERS

# Calculate end-to-end speedup
forward_speedup = unsloth_forward_time / triton_forward_time

print(f"Results match in forward pass: {torch.allclose(output_unsloth, output_triton)}")
print(f"Unsloth forward time: {unsloth_forward_time*1000:.2f} ms")
print(f"Triton forward time: {triton_forward_time*1000:.2f} ms")
print(f"End-to-end speedup: {forward_speedup:.2f}x")