In [None]:
import torch
import triton

from peft.utils.integrations import dequantize_module_weight
from bitsandbytes.nn import Linear4bit

from kernel_v0 import dequantize_triton_v0
from kernel_v1 import dequantize_triton_v1
from kernel_v2 import dequantize_triton_v2
from kernel_v3 import dequantize_triton_v3
from kernel_v4 import dequantize_triton_v4


def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Build a bias-free bitsandbytes ``Linear4bit`` layer using NF4 quantization.

    Args:
        hd: Forwarded as the first positional argument of ``Linear4bit``.
        m: Forwarded as the second positional argument of ``Linear4bit``.
        dtype: Passed through as ``compute_dtype``; defaults to ``torch.float16``.

    Returns:
        The newly constructed ``Linear4bit`` module, created with
        ``bias=None``, ``compress_statistics=True`` and ``quant_type="nf4"``.
    """
    layer_options = dict(
        bias=None,
        compute_dtype=dtype,
        compress_statistics=True,
        quant_type="nf4",
    )
    return Linear4bit(hd, m, **layer_options)

In [None]:
# Correctness check: every Triton kernel must reproduce the weight returned by
# peft's `dequantize_module_weight` for the same quantized layer.
linear = bnb_Linear4bit(128 * 5, 512 * 5, dtype=torch.float16).to("cuda")
# NOTE(review): presumably this sets the dtype the dequantized output is
# produced in -- confirm against the bitsandbytes QuantState implementation.
linear.quant_state.dtype = torch.float16

out_reference = dequantize_module_weight(linear)
out_triton_v0 = dequantize_triton_v0(linear.weight, linear.quant_state)
out_triton_v1 = dequantize_triton_v1(linear.weight, linear.quant_state)
out_triton_v2 = dequantize_triton_v2(linear.weight, linear.quant_state)
out_triton_v3 = dequantize_triton_v3(linear.weight, linear.quant_state)
out_triton_v4 = dequantize_triton_v4(linear.weight, linear.quant_state)

# `assert_close(actual, expected)` scales its relative tolerance by `expected`
# and labels both sides in the failure message, so the kernel under test goes
# first and the reference second.
torch.testing.assert_close(out_triton_v0, out_reference)
torch.testing.assert_close(out_triton_v1, out_reference)
torch.testing.assert_close(out_triton_v2, out_reference)
torch.testing.assert_close(out_triton_v3, out_reference)
torch.testing.assert_close(out_triton_v4, out_reference)

In [None]:
# One benchmark configuration per dtype. The two differ only in the plot-name
# suffix and in the `bfloat16_instead_of_float16` flag handed to `benchmark`.
configs = [
    triton.testing.Benchmark(
        x_names=["M", "N"],
        # Square problem sizes: 512 * 5, 512 * 7, ..., 512 * 39.
        x_vals=[(512 * step, 512 * step) for step in range(5, 40, 2)],
        line_arg="provider",
        line_vals=["bitsandbytes", "triton_v0", "triton_v1", "triton_v2", "triton_v3", "triton_v4"],
        line_names=["bitsandbytes", "Triton V0", "Triton V1", "Triton V2", "Triton V3", "Triton V4"],
        styles=[("black", "--"), ("blue", "-"), ("red", "-"), ("orange", "-"), ("purple", "-"), ("yellow", "-")],
        ylabel="ms",
        plot_name=f"QLORA Dequantize Benchmark-{dtype_label}",
        args={"bfloat16_instead_of_float16": use_bfloat16},
    )
    for dtype_label, use_bfloat16 in (("bfloat16", True), ("float16", False))
]

@triton.testing.perf_report(configs)
def benchmark(M, N, provider, bfloat16_instead_of_float16):
    """Time one dequantization provider on a freshly built 4-bit linear layer.

    Args:
        M: Forwarded to ``bnb_Linear4bit`` as its first argument.
        N: Forwarded to ``bnb_Linear4bit`` as its second argument.
        provider: Which implementation to time: ``"bitsandbytes"`` (the
            reference, via ``dequantize_module_weight``) or one of
            ``"triton_v0"`` .. ``"triton_v4"``.
        bfloat16_instead_of_float16: When true the layer uses
            ``torch.bfloat16``; otherwise ``torch.float16``.

    Returns:
        ``(median_ms, p20_ms, p80_ms)``: the 0.5, 0.2 and 0.8 quantiles
        reported by ``triton.testing.do_bench``, in the
        ``(mean, min, max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    # Personal implementations (Triton), keyed by provider name.
    triton_kernels = {
        "triton_v0": dequantize_triton_v0,
        "triton_v1": dequantize_triton_v1,
        "triton_v2": dequantize_triton_v2,
        "triton_v3": dequantize_triton_v3,
        "triton_v4": dequantize_triton_v4,
    }
    # Reject unknown providers before allocating anything on the GPU.
    if provider != "bitsandbytes" and provider not in triton_kernels:
        raise ValueError(f"Unknown benchmark provider: {provider!r}")

    dtype = torch.bfloat16 if bfloat16_instead_of_float16 else torch.float16
    linear = bnb_Linear4bit(M, N, dtype=dtype).to("cuda")
    linear.quant_state.dtype = dtype

    if provider == "bitsandbytes":
        # Reference implementation reached through peft's helper.
        def run_once():
            return dequantize_module_weight(linear)
    else:
        kernel = triton_kernels[provider]

        def run_once():
            return kernel(linear.weight, linear.quant_state)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(run_once, quantiles=quantiles)
    # The y-axis is raw milliseconds (not an inverted throughput metric), so the
    # lower quantile is the minimum and the upper quantile the maximum.
    return ms, min_ms, max_ms

# Run every configuration, showing the plots and printing the timing tables.
# NOTE(review): with a list of configs, `run(return_df=True)` appears to return
# one DataFrame per config rather than a single frame -- confirm against the
# installed Triton version.
df = benchmark.run(
    show_plots=True,
    print_data=True,
    return_df=True,
)