diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index cb9b7f6ec2fd2..f07307fcd2f9d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, << type << " illegal: cannot handle zero-element tensors\n"); return nullptr; } + if (arrayElemCount > std::numeric_limits::max()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot fit tensor into target type\n"); + return nullptr; + } Type arrayElemType = convertScalarType(targetEnv, options, scalarType); if (!arrayElemType) diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir index b69c2d0408d17..65c6e0587129e 100644 --- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir +++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir @@ -79,3 +79,12 @@ func.func @tensor_2d_empty() -> () { %x = arith.constant dense<> : tensor<2x0xi32> return } + +// Tensors with more than UINT32_MAX elements cannnot fit in a spirv.array. +// Test that they are not lowered. +// CHECK-LABEL: func @very_large_tensor +// CHECK-NEXT: arith.constant dense<1> +func.func @very_large_tensor() -> () { + %x = arith.constant dense<1> : tensor<4294967296xi32> + return +}