diff --git a/flang-rt/lib/cuda/memory.cpp b/flang-rt/lib/cuda/memory.cpp index d830580e6a066..78270fef07c36 100644 --- a/flang-rt/lib/cuda/memory.cpp +++ b/flang-rt/lib/cuda/memory.cpp @@ -25,23 +25,22 @@ extern "C" { void *RTDEF(CUFMemAlloc)( std::size_t bytes, unsigned type, const char *sourceFile, int sourceLine) { void *ptr = nullptr; - if (bytes != 0) { - if (type == kMemTypeDevice) { - if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) { - CUDA_REPORT_IF_ERROR( - cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); - } else { - CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes)); - } - } else if (type == kMemTypeManaged || type == kMemTypeUnified) { + bytes = bytes ? bytes : 1; + if (type == kMemTypeDevice) { + if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) { CUDA_REPORT_IF_ERROR( cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); - } else if (type == kMemTypePinned) { - CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes)); } else { - Terminator terminator{sourceFile, sourceLine}; - terminator.Crash("unsupported memory type"); + CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes)); } + } else if (type == kMemTypeManaged || type == kMemTypeUnified) { + CUDA_REPORT_IF_ERROR( + cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); + } else if (type == kMemTypePinned) { + CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes)); + } else { + Terminator terminator{sourceFile, sourceLine}; + terminator.Crash("unsupported memory type"); } return ptr; } diff --git a/flang-rt/unittests/Runtime/CUDA/Memory.cpp b/flang-rt/unittests/Runtime/CUDA/Memory.cpp index f2e17870f7999..c84c54a1376e5 100644 --- a/flang-rt/unittests/Runtime/CUDA/Memory.cpp +++ b/flang-rt/unittests/Runtime/CUDA/Memory.cpp @@ -35,6 +35,12 @@ TEST(MemoryCUFTest, SimpleAllocTramsferFree) { RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__); } +TEST(MemoryCUFTest, AllocZero) { + int *dev = (int *)RTNAME(CUFMemAlloc)(0, kMemTypeDevice, __FILE__, __LINE__); + EXPECT_TRUE(dev != 0); + RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__); +} + static OwningPtr createAllocatable( Fortran::common::TypeCategory tc, int kind, int rank = 1) { return Descriptor::Create(TypeCode{tc, kind}, kind, nullptr, rank, nullptr,