Revert "enable cat for cuda bits types (pytorch#115044)"
This reverts commit 4cf97c4.

Reverted pytorch#115044 on behalf of https://github.com/malfet due to This breaks ROCM ([comment](pytorch#115044 (comment)))
pytorchmergebot authored and hyperfraise committed Dec 21, 2023
1 parent 4031254 commit ccd3a73
Showing 3 changed files with 17 additions and 69 deletions.
aten/src/ATen/native/cuda/Shape.cu (12 additions, 35 deletions)
@@ -239,7 +239,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
                    int nDims, c10::MemoryFormat memory_format) {
   // First, let's set up our kernel parameters. We start with a raw pointer to
   // the storage for the output Tensor.
-  scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
+  scalar_t *data = out.mutable_data_ptr<scalar_t>();
   CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
   TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;
 
@@ -289,7 +289,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
         dimSize = inputs[i+batchCounter].get().size(dimension);
       }
 
-      catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
+      catMetaData.input[batchCounter] = inputs[i+batchCounter].get().const_data_ptr<scalar_t>();
       catMetaData.offset[batchCounter] = offset;
       catMetaData.dimSize[batchCounter] = dimSize;
       catMetaData.nElements[batchCounter] = inputs[i+batchCounter].get().numel();
@@ -375,10 +375,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 #undef HANDLE_CASE
   }
 }
-// The kernels are templated on an opaque, self-aligned type of the correct
-// size to avoid redundant kernels for different types of the same size.
-template <unsigned N> struct alignas(N) OpaqueType { char data[N]; };
-
 } // namespace
 
 TORCH_IMPL_FUNC(cat_out_cuda)
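
The OpaqueType comment deleted in the hunk above describes a kernel-deduplication trick: cat never does arithmetic on the elements, so its kernels can be instantiated on an opaque, self-aligned type of the right size instead of the real dtype, and every dtype of that size shares one instantiation. A minimal host-side C++ sketch of the idea (hypothetical names such as copy_elems and dtype4; not the PyTorch kernels):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Opaque, self-aligned type of size N: its size is the only property a pure
// memory-movement routine cares about.
template <unsigned N> struct alignas(N) OpaqueType { char data[N]; };

// Stand-in for a cat-style copy kernel: it only moves bytes, so it can be
// instantiated on OpaqueType<sizeof(T)> rather than on T itself.
template <typename T>
void copy_elems(T* dst, const T* src, std::size_t n) {
  std::memcpy(dst, src, n * sizeof(T));
}

int main() {
  float fsrc[4] = {1.f, 2.f, 3.f, 4.f}, fdst[4];
  std::int32_t isrc[4] = {1, 2, 3, 4}, idst[4];

  // One instantiation, OpaqueType<4>, serves float, int32_t, and any other
  // 4-byte dtype, which is how same-size bits types could reuse it.
  using dtype4 = OpaqueType<4>;
  copy_elems(reinterpret_cast<dtype4*>(fdst), reinterpret_cast<const dtype4*>(fsrc), 4);
  copy_elems(reinterpret_cast<dtype4*>(idst), reinterpret_cast<const dtype4*>(isrc), 4);

  std::printf("%g %d\n", fdst[3], idst[3]);  // prints: 4 4
  return 0;
}

With the revert, Shape.cu goes back to instantiating parallel_cat directly on scalar_t, so this deduplication, and the bits-type dispatch that relied on it, is dropped.
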
@@ -416,48 +412,29 @@ TORCH_IMPL_FUNC(cat_out_cuda)
   // memory. Therefore, we could pass more inputs to cuda threads.
   // For non-contiguous, we reduce the number of inputs passed to cuda kernel due to the limitation
   // of constant memory.
-
-
-
   if (materialized.size() > 1 &&
       result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
       at::cuda::detail::canUse32BitIndexMath(result) &&
       all_contiguous &&
       all32BitIndexable &&
       all_same_dtype) {
-    if (isBitsType(result.scalar_type())) {
-      AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
-        using dtype = OpaqueType<sizeof(scalar_t)>;
-        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
-      });
-    } else {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-          kComplexHalf, kHalf, kBool, kBFloat16,
-          result.scalar_type(), "cat_cuda", [&]() {
-            using dtype = OpaqueType<sizeof(scalar_t)>;
-            parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
-          });
-    }
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+        kComplexHalf, kHalf, kBool, kBFloat16,
+        result.scalar_type(), "cat_cuda", [&]() {
+          parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
+        });
   } else if (materialized.size() > 1 &&
       result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
       at::cuda::detail::canUse32BitIndexMath(result) &&
       nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
       all32BitIndexable &&
       all_same_dtype &&
       memory_format == c10::MemoryFormat::Contiguous) {
-    if (isBitsType(result.scalar_type())) {
-      AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
-        using dtype = OpaqueType<sizeof(scalar_t)>;
-        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
-      });
-    } else {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-          kComplexHalf, kHalf, kBool, kBFloat16,
-          result.scalar_type(), "cat_cuda", [&]() {
-            using dtype = OpaqueType<sizeof(scalar_t)>;
-            parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
-          });
-    }
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+        kComplexHalf, kHalf, kBool, kBFloat16,
+        result.scalar_type(), "cat_cuda", [&]() {
+          parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
+        });
   } else {
     int64_t offset = 0;
     for (const Tensor& t : materialized) {
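
Both branches above rely on ATen's AT_DISPATCH_* macros to turn the runtime result.scalar_type() into a compile-time scalar_t for the lambda body. A rough, self-contained sketch of that pattern (hypothetical names dispatch_by_dtype and TypeTag; the real macros splice a `using scalar_t = ...;` into the lambda via preprocessor expansion rather than passing a tag):

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>

// Tag that carries a concrete element type into the generic lambda.
template <typename T> struct TypeTag { using type = T; };

enum class ScalarType { Float, Int, Byte };

// Switch on the runtime dtype and run the same body with the matching
// compile-time type bound; unknown dtypes fail loudly, as AT_DISPATCH does.
template <typename Body>
void dispatch_by_dtype(ScalarType dtype, const std::string& name, Body&& body) {
  switch (dtype) {
    case ScalarType::Float: body(TypeTag<float>{});        break;
    case ScalarType::Int:   body(TypeTag<std::int32_t>{}); break;
    case ScalarType::Byte:  body(TypeTag<std::uint8_t>{}); break;
    default: throw std::runtime_error(name + ": dtype not implemented");
  }
}

int main() {
  dispatch_by_dtype(ScalarType::Int, "cat_sketch", [&](auto tag) {
    using scalar_t = typename decltype(tag)::type;  // plays the role of AT_DISPATCH's scalar_t
    std::printf("element size = %zu bytes\n", sizeof(scalar_t));  // prints: element size = 4 bytes
  });
  return 0;
}
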
test/quantization/core/experimental/test_bits.py (4 additions, 26 deletions)
@@ -1,14 +1,10 @@
 # Owner(s): ["oncall: quantization"]
 
 import torch
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
-
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._pytree import tree_map
 
-import itertools
-
 class Int16Tensor(torch.Tensor):
     def __new__(cls, elem):
         assert elem.dtype == torch.bits16
@@ -45,42 +41,24 @@ def __repr__(self) -> str:
 
 
 class TestBits(TestCase):
-    def test_types(self, device):
+    def test_types(self):
         bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
         for bits_type in bits_types:
-            _ = torch.zeros(20, dtype=torch.int32, device=device).view(bits_type)
-            _ = torch.empty(20, dtype=bits_type, device=device)
-            x = torch.randint(100, (20, 20), dtype=torch.int8, device=device).view(bits_type)
+            _ = torch.zeros(20, dtype=torch.int32).view(bits_type)
+            _ = torch.empty(20, dtype=bits_type)
+            x = torch.randint(100, (20, 20), dtype=torch.int8).view(bits_type)
             y = x.t().contiguous()
             view_type = torch.int8 if x.element_size() == 1 else torch.int16
             self.assertEqual(x.t().view(view_type), y.view(view_type))
             y = x.t().clone()
             self.assertEqual(x.t().view(view_type), y.view(view_type))
 
-    def test_cat(self, device):
-        bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
-        for bits_type in bits_types:
-            view_type = torch.int8 if bits_type.itemsize == 1 else torch.int16
-            x_int = torch.randint(100, (512, 512), dtype=view_type, device=device)
-            x = x_int.view(bits_type)
-            y_int = torch.randint(100, (512, 512), dtype=view_type, device=device)
-            y = y_int.view(bits_type)
-            for dim, transpose in itertools.product(range(x_int.ndim), (True, False)):
-                y_ref = y_int.t() if transpose else y_int
-                y_b = y.t() if transpose else y
-                z_ref = torch.cat([x_int, y_ref], dim=dim)
-                z = torch.cat([x, y_b], dim=dim)
-                self.assertEqual(z_ref, z.view(view_type))
-
-
     def test_subclass(self):
         t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
         s = Int16Tensor(t)
         s = s + 1 - 1
         self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))
 
-instantiate_device_type_tests(TestBits, globals())
-
 
 if __name__ == '__main__':
     run_tests()
test/test_quantization.py (1 addition, 8 deletions)
@@ -153,14 +153,7 @@
     logging.warning(e)
 
 # Experimental functionality
-try:
-    from quantization.core.experimental.test_bits import TestBitsCPU  # noqa: F401
-except ImportError as e:
-    logging.warning(e)
-try:
-    from quantization.core.experimental.test_bits import TestBitsCUDA  # noqa: F401
-except ImportError as e:
-    logging.warning(e)
+from quantization.core.experimental.test_bits import TestBits  # noqa: F401
 try:
     from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU  # noqa: F401
 except ImportError as e:
