Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -2613,9 +2613,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{

// per threadblock:
// load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
// load step-by-step in chunks of [warp_size,warps]: [1,warp_size] * [warp_size,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1 x BNB_WARP_SIZE * BNB_WARP_SIZE x 4 -> 1x4 outputs per thread block
// 1 x warp_size * warp_size x 4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];

Expand Down
6 changes: 3 additions & 3 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
Expand Down Expand Up @@ -1205,7 +1205,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
"""
Test that we can successfully quantize a large tensor. Note that the following limitations apply:
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_gemv_eye_4bit(self, device, storage_type, dtype):
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
pytest.skip("This configuration is not supported on HPU.")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_linear8bit_serialization(linear8bit):
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_parametrize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn

from bitsandbytes import functional as F
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
from bitsandbytes.nn.parametrize import (
Bnb4bitParametrization,
replace_parameter_4bit,
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize(
"blocksize",
[64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
)
def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
"""Test basic parameter replacement with 4-bit quantization on different dtypes."""
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_quant_state_preservation(device, dtype):

module = ParametrizeTestModule(device=device, dtype=dtype)

blocksize = 128 if HIP_ENVIRONMENT else 64
blocksize = 128 if ROCM_WARP_SIZE_64 else 64

# Apply parametrization with specific settings
replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_multiple_parameters(device, dtype):
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize(
"blocksize",
[64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
)
def test_different_blocksizes(device, dtype, blocksize):
"""Test parametrization with different block sizes to verify flexibility."""
Expand Down