From c6ba287d74192caf75d9fb39fb2952ab9ce05f12 Mon Sep 17 00:00:00 2001
From: sstamenk
Date: Fri, 7 Nov 2025 23:54:18 +0000
Subject: [PATCH 1/2] Enable even more unit tests for warp size 32

---
 tests/test_functional.py   | 6 +++---
 tests/test_linear8bitlt.py | 4 ++--
 tests/test_parametrize.py  | 8 ++++----
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/test_functional.py b/tests/test_functional.py
index 8367c0850..d420ff352 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -491,7 +491,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
     @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
     @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
     @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
-    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+    @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
     def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
         def min_max(x):
             maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1205,7 +1205,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
     @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
-    @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
+    @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
     def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
         """
         Test that we can successfully quantize a large tensor. Note that the following limitations apply:
@@ -1428,7 +1428,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+    @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
     def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 6da3c28f8..a0725d605 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -9,7 +9,7 @@
 import torch
 
 import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
@@ -234,7 +234,7 @@ def test_linear8bit_serialization(linear8bit):
 @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
-@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
 def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")
diff --git a/tests/test_parametrize.py b/tests/test_parametrize.py
index cf0871c67..be4a6b52c 100644
--- a/tests/test_parametrize.py
+++ b/tests/test_parametrize.py
@@ -3,7 +3,7 @@
 import torch.nn as nn
 
 from bitsandbytes import functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.nn.parametrize import (
     Bnb4bitParametrization,
     replace_parameter_4bit,
@@ -39,7 +39,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 @pytest.mark.parametrize(
     "blocksize",
-    [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
+    [64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
 )
 def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
     """Test basic parameter replacement with 4-bit quantization on different dtypes."""
@@ -267,7 +267,7 @@ def test_quant_state_preservation(device, dtype):
 
     module = ParametrizeTestModule(device=device, dtype=dtype)
 
-    blocksize = 128 if HIP_ENVIRONMENT else 64
+    blocksize = 128 if ROCM_WARP_SIZE_64 else 64
 
     # Apply parametrization with specific settings
     replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)
@@ -326,7 +326,7 @@ def test_multiple_parameters(device, dtype):
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize(
     "blocksize",
-    [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
+    [64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
 )
 def test_different_blocksizes(device, dtype, blocksize):
     """Test parametrization with different block sizes to verify flexibility."""

From e88456ea92ed6a3f46b1098e0905b01f5493aa4f Mon Sep 17 00:00:00 2001
From: Strahinja Stamenkovic
Date: Wed, 12 Nov 2025 14:45:08 +0100
Subject: [PATCH 2/2] Revert comment changes from previous PR for consistency

---
 csrc/kernels.hip | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/kernels.hip b/csrc/kernels.hip
index d1b8fc335..eb139c6ce 100644
--- a/csrc/kernels.hip
+++ b/csrc/kernels.hip
@@ -2613,9 +2613,9 @@ template __global__ void kgemm_4bit_inferenc
 {
 
     // per threadblock:
-    // load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
+    // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
     // 4 warps -> 4 loads per iter
-    // 1 x BNB_WARP_SIZE * BNB_WARP_SIZE x 4 -> 1x4 outputs per thread block
+    // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
 
     typedef hipcub::WarpReduce WarpReduce;
     __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];
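
Note: ROCM_WARP_SIZE_64 is imported from bitsandbytes.cextension and is true only on ROCm devices with a 64-lane wavefront, so the skips above no longer exclude warp-size-32 (RDNA) GPUs the way the old HIP_ENVIRONMENT gate did. The sketch below is a minimal illustration, under stated assumptions, of how such a flag can be derived and used to pick valid blocksizes; it is not the actual definition in bitsandbytes/cextension.py, and the helper name rocm_warp_size_64 is hypothetical.

# Minimal sketch (assumption, not the actual bitsandbytes implementation) of how a
# ROCM_WARP_SIZE_64-style flag can be computed so that tests are skipped only on
# ROCm devices with 64-lane wavefronts, not on warp-size-32 GPUs.
import torch


def rocm_warp_size_64() -> bool:
    """Return True only for a ROCm build running on a 64-lane wavefront GPU."""
    if not torch.cuda.is_available() or torch.version.hip is None:
        return False  # CPU-only or CUDA builds never take the warp-size-64 path
    props = torch.cuda.get_device_properties(0)
    # Older PyTorch builds may not expose warp_size; default to 64 to stay conservative.
    return getattr(props, "warp_size", 64) == 64


# Example: mirror the parametrizations touched in this patch, which drop
# blocksize 64 only when the warp-size-64 flag is set.
BLOCKSIZES = [64, 128, 256] if not rocm_warp_size_64() else [128, 256]

With a gate of this shape, warp-size-32 ROCm machines run the same blocksize-64 and GEMV cases as CUDA, while 64-lane wavefront devices keep the previous skips, which matches the intent of the PATCH 1/2 subject line.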