From 8f3526d616f0a3a5e0ca12d074f72de5c57e30f6 Mon Sep 17 00:00:00 2001 From: Mahmood Yassin Date: Mon, 24 Nov 2025 14:55:25 +0200 Subject: [PATCH] [CIR] Lower element-wise shift operations for OpenCL Implement OpenCL 6.3j shift semantics for both left and right shifts in CIR. The shift amount is now masked by the bit width of the LHS operand, handling both scalar and vector RHS operands. --- clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 87 ++++++++++++++++--- clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl | 32 +++++++ 2 files changed, 107 insertions(+), 12 deletions(-) create mode 100644 clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index a1adea6cd172..a5674cf880e3 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1537,6 +1537,53 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) { Ops.LHS, Ops.RHS); } +// Helper to apply OpenCL-style shift masking. It handles both vector and scalar +// RHS. According to OpenCL 6.3j, shift values are effectively modulo word size +// of LHS. +static mlir::Value maskShiftAmount(CIRGenFunction &CGF, + CIRGenBuilderTy &Builder, + const BinOpInfo &Ops) { + // LHS determines bit width + unsigned bitWidth = [&]() { + if (auto vecTy = mlir::dyn_cast(Ops.LHS.getType())) { + auto elemTy = mlir::cast(vecTy.getElementType()); + return elemTy.getWidth(); + } + auto lhsTy = mlir::cast(Ops.LHS.getType()); + return lhsTy.getWidth(); + }(); + + mlir::Location loc = Ops.RHS.getLoc(); + + // Vector RHS: perform element-wise masking + if (auto vecTy = mlir::dyn_cast(Ops.RHS.getType())) { + // Create scalar bit-width constant of element integer type + auto elemTy = mlir::cast(vecTy.getElementType()); + mlir::Value bitMaskConstElem = + Builder.getConstInt(loc, elemTy, bitWidth - 1); + + // Splat bit-width constant to a vector with same shape as RHS + mlir::Value bitWidthConstVec = + cir::VecSplatOp::create(Builder, loc, vecTy, bitMaskConstElem); + + // Make RHS a vector if it isn't one already (splat scalar to vector) + mlir::Value rhsVec = Ops.RHS; + if (!mlir::isa(rhsVec.getType())) + rhsVec = cir::VecSplatOp::create(Builder, loc, vecTy, rhsVec); + + // Compute element-wise: rhs % bitWidth ==> rhs & (bitWidth - 1) + return cir::BinOp::create(Builder, loc, vecTy, cir::BinOpKind::And, rhsVec, + bitWidthConstVec); + } + + // Scalar RHS: simple integer mask + auto rhsTy = mlir::cast(Ops.RHS.getType()); + mlir::Value mask = Builder.getConstInt(loc, rhsTy, bitWidth - 1); + return cir::BinOp::create(Builder, loc, rhsTy, cir::BinOpKind::And, Ops.RHS, + mask); +} + +// Emit bitwise left-shift mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) { // TODO: This misses out on the sanitizer check below. if (Ops.isFixedPointOp()) @@ -1557,18 +1604,27 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) { bool SanitizeExponent = CGF.SanOpts.has(SanitizerKind::ShiftExponent); // OpenCL 6.3j: shift values are effectively % word size of LHS. - if (CGF.getLangOpts().OpenCL) - llvm_unreachable("NYI"); - else if ((SanitizeBase || SanitizeExponent) && - mlir::isa(Ops.LHS.getType())) { - llvm_unreachable("NYI"); + if (CGF.getLangOpts().OpenCL) { + // Mask RHS according to OpenCL rule + mlir::Value rhsMasked = maskShiftAmount(CGF, Builder, Ops); + + return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), + CGF.convertType(Ops.FullType), Ops.LHS, + rhsMasked, CGF.getBuilder().getUnitAttr()); } + // Sanitizer handling NYI (signed base, unsigned base, or exponent). + if ((SanitizeBase || SanitizeExponent) && + mlir::isa(Ops.LHS.getType())) + llvm_unreachable("NYI"); + + // Default case: emit simple shift return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.convertType(Ops.FullType), Ops.LHS, Ops.RHS, CGF.getBuilder().getUnitAttr()); } +// Emit bitwise right-shift mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &Ops) { // TODO: This misses out on the sanitizer check below. if (Ops.isFixedPointOp()) @@ -1579,15 +1635,22 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &Ops) { // promote or truncate the RHS to the same size as the LHS. // OpenCL 6.3j: shift values are effectively % word size of LHS. - if (CGF.getLangOpts().OpenCL) - llvm_unreachable("NYI"); - else if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) && - mlir::isa(Ops.LHS.getType())) { - llvm_unreachable("NYI"); + if (CGF.getLangOpts().OpenCL) { + // Mask RHS as per OpenCL shift rule + mlir::Value rhsMasked = maskShiftAmount(CGF, Builder, Ops); + + return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), + CGF.convertType(Ops.FullType), Ops.LHS, + rhsMasked); } - // Note that we don't need to distinguish unsigned treatment at this - // point since it will be handled later by LLVM lowering. + // Sanitizer handling placeholder for right-shift. + if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) && + mlir::isa(Ops.LHS.getType())) + llvm_unreachable("NYI"); + + // No special handling needed for signed/unsigned distinction at CIR level. + // LLVM lowering will take care of semantics differences. return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.convertType(Ops.FullType), Ops.LHS, Ops.RHS); } diff --git a/clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl b/clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl new file mode 100644 index 000000000000..0d4d11244526 --- /dev/null +++ b/clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl @@ -0,0 +1,32 @@ +// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-cir -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=CIR + +// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-llvm -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM + +// RUN: %clang_cc1 %s -cl-std=CL2.0 -emit-llvm -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=OG-LLVM + +typedef __attribute__(( ext_vector_type(2) )) int int2; + +// CIR: %[[LHS:.*]] = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector +// CIR: %[[WIDTH:.*]] = cir.const #cir.const_vector<[#cir.int<31> : !s32i, #cir.int<31> : !s32i]> : !cir.vector +// CIR: %[[MASK:.*]] = cir.binop(and, %[[LHS]], %[[WIDTH]]) : !cir.vector +// CIR: cir.shift(right, %{{.*}} : !cir.vector, %[[MASK]] : !cir.vector) -> !cir.vector +// LLVM: ashr <2 x i32> %{{.*}}, splat (i32 3) +// OG-LLVM: ashr <2 x i32> %x, splat (i32 3) +int2 shr(int2 x) +{ + return x >> 3; +} + +// CIR: %[[LHS:.*]] = cir.const #cir.int<5> : !s32i +// CIR: %[[WIDTH:.*]] = cir.const #cir.int<31> : !s32i +// CIR: %[[MASK:.*]] = cir.binop(and, %[[LHS]], %[[WIDTH]]) : !s32i +// CIR: cir.shift(left, %{{.*}} : !s32i, %[[MASK]] : !s32i) -> !s32i +// LLVM: shl i16 %{{.*}}, 5 +// OG-LLVM: shl i16 %x, 5 +short shl(short x) +{ + return x << 5; +} \ No newline at end of file