Skip to content

Commit

Permalink
fix the return type of multiply_image_and_position
Browse files Browse the repository at this point in the history
for images of type int8, multiply_image_and_position was limited to positions < 256 (and for int16 to positions < 2^16), because the return type was the same as the image type. However int8 images can have positions larger than that because position indices are of type int32. This commit ensures that the return type is at least int32 so that multiply_image_and_position can handle positions larger than the maximum integer for int8 and int16 images.

closes issue tier1::multiply_image_and_position_func() implicitly assumes that the position has the same type as the input. #287
  • Loading branch information
thawn committed Apr 28, 2024
1 parent 106dbb9 commit 3c0de89
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
16 changes: 15 additions & 1 deletion clic/src/tier1/multiply_image_and_position.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,21 @@ multiply_image_and_position_func(const Device::Pointer & device,
Array::Pointer dst,
int dimension) -> Array::Pointer
{
tier0::create_like(src, dst);
auto type = src->dtype();
switch (src->dtype())
{
case dType::INT8:
case dType::INT16:
type = dType::INT32;
break;
case dType::UINT8:
case dType::UINT16:
type = dType::UINT32;
break;
default:
break;
}
tier0::create_like(src, dst, type);
const KernelInfo kernel = { "multiply_image_and_position", kernel::multiply_image_and_position };
const ParameterList params = { { "src", src }, { "dst", dst }, { "index", dimension } };
const RangeArray range = { dst->width(), dst->height(), dst->depth() };
Expand Down
29 changes: 29 additions & 0 deletions tests/tier1/test_multiply_image_and_position.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,35 @@ TEST_P(TestMultiplyPixelAndCoord, execute)
}
}

TEST_P(TestMultiplyPixelAndCoord, returnType)
{
std::string param = GetParam();
cle::BackendManager::getInstance().setBackend(param);
auto device = cle::BackendManager::getInstance().getBackend().getDevice("", "all");
device->setWaitToFinish(true);

for (cle::dType type : { cle::dType::INT8, cle::dType::INT16 })
{
auto gpu_input = cle::Array::create(5, 3, 1, 3, type, cle::mType::BUFFER, device);
auto gpu_output = cle::tier1::multiply_image_and_position_func(device, gpu_input, nullptr, 0);
EXPECT_EQ(gpu_output->dtype(), cle::dType::INT32);
}

for (cle::dType type : { cle::dType::UINT8, cle::dType::UINT16 })
{
auto gpu_input = cle::Array::create(5, 3, 1, 3, type, cle::mType::BUFFER, device);
auto gpu_output = cle::tier1::multiply_image_and_position_func(device, gpu_input, nullptr, 0);
EXPECT_EQ(gpu_output->dtype(), cle::dType::UINT32);
}

for (cle::dType type : { cle::dType::FLOAT, cle::dType::UINT32, cle::dType::INT32 })
{
auto gpu_input = cle::Array::create(5, 3, 1, 3, type, cle::mType::BUFFER, device);
auto gpu_output = cle::tier1::multiply_image_and_position_func(device, gpu_input, nullptr, 0);
EXPECT_EQ(gpu_output->dtype(), type);
}
}

std::vector<std::string>
getParameters()
{
Expand Down

0 comments on commit 3c0de89

Please sign in to comment.