diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index a1adea6cd172..a5674cf880e3 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1537,6 +1537,53 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) { Ops.LHS, Ops.RHS); } +// Helper to apply OpenCL-style shift masking. It handles both vector and scalar +// RHS. According to OpenCL 6.3j, shift values are effectively modulo word size +// of LHS. +static mlir::Value maskShiftAmount(CIRGenFunction &CGF, + CIRGenBuilderTy &Builder, + const BinOpInfo &Ops) { + // LHS determines bit width + unsigned bitWidth = [&]() { + if (auto vecTy = mlir::dyn_cast(Ops.LHS.getType())) { + auto elemTy = mlir::cast(vecTy.getElementType()); + return elemTy.getWidth(); + } + auto lhsTy = mlir::cast(Ops.LHS.getType()); + return lhsTy.getWidth(); + }(); + + mlir::Location loc = Ops.RHS.getLoc(); + + // Vector RHS: perform element-wise masking + if (auto vecTy = mlir::dyn_cast(Ops.RHS.getType())) { + // Create scalar bit-width constant of element integer type + auto elemTy = mlir::cast(vecTy.getElementType()); + mlir::Value bitMaskConstElem = + Builder.getConstInt(loc, elemTy, bitWidth - 1); + + // Splat bit-width constant to a vector with same shape as RHS + mlir::Value bitWidthConstVec = + cir::VecSplatOp::create(Builder, loc, vecTy, bitMaskConstElem); + + // Make RHS a vector if it isn't one already (splat scalar to vector) + mlir::Value rhsVec = Ops.RHS; + if (!mlir::isa(rhsVec.getType())) + rhsVec = cir::VecSplatOp::create(Builder, loc, vecTy, rhsVec); + + // Compute element-wise: rhs % bitWidth ==> rhs & (bitWidth - 1) + return cir::BinOp::create(Builder, loc, vecTy, cir::BinOpKind::And, rhsVec, + bitWidthConstVec); + } + + // Scalar RHS: simple integer mask + auto rhsTy = mlir::cast(Ops.RHS.getType()); + mlir::Value mask = Builder.getConstInt(loc, rhsTy, bitWidth - 1); + return cir::BinOp::create(Builder, loc, rhsTy, cir::BinOpKind::And, Ops.RHS, + mask); +} + +// Emit bitwise left-shift mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) { // TODO: This misses out on the sanitizer check below. if (Ops.isFixedPointOp()) @@ -1557,18 +1604,27 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) { bool SanitizeExponent = CGF.SanOpts.has(SanitizerKind::ShiftExponent); // OpenCL 6.3j: shift values are effectively % word size of LHS. - if (CGF.getLangOpts().OpenCL) - llvm_unreachable("NYI"); - else if ((SanitizeBase || SanitizeExponent) && - mlir::isa(Ops.LHS.getType())) { - llvm_unreachable("NYI"); + if (CGF.getLangOpts().OpenCL) { + // Mask RHS according to OpenCL rule + mlir::Value rhsMasked = maskShiftAmount(CGF, Builder, Ops); + + return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), + CGF.convertType(Ops.FullType), Ops.LHS, + rhsMasked, CGF.getBuilder().getUnitAttr()); } + // Sanitizer handling NYI (signed base, unsigned base, or exponent). + if ((SanitizeBase || SanitizeExponent) && + mlir::isa(Ops.LHS.getType())) + llvm_unreachable("NYI"); + + // Default case: emit simple shift return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.convertType(Ops.FullType), Ops.LHS, Ops.RHS, CGF.getBuilder().getUnitAttr()); } +// Emit bitwise right-shift mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &Ops) { // TODO: This misses out on the sanitizer check below. if (Ops.isFixedPointOp()) @@ -1579,15 +1635,22 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &Ops) { // promote or truncate the RHS to the same size as the LHS. // OpenCL 6.3j: shift values are effectively % word size of LHS. - if (CGF.getLangOpts().OpenCL) - llvm_unreachable("NYI"); - else if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) && - mlir::isa(Ops.LHS.getType())) { - llvm_unreachable("NYI"); + if (CGF.getLangOpts().OpenCL) { + // Mask RHS as per OpenCL shift rule + mlir::Value rhsMasked = maskShiftAmount(CGF, Builder, Ops); + + return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), + CGF.convertType(Ops.FullType), Ops.LHS, + rhsMasked); } - // Note that we don't need to distinguish unsigned treatment at this - // point since it will be handled later by LLVM lowering. + // Sanitizer handling placeholder for right-shift. + if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) && + mlir::isa(Ops.LHS.getType())) + llvm_unreachable("NYI"); + + // No special handling needed for signed/unsigned distinction at CIR level. + // LLVM lowering will take care of semantics differences. return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.convertType(Ops.FullType), Ops.LHS, Ops.RHS); } diff --git a/clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl b/clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl new file mode 100644 index 000000000000..0d4d11244526 --- /dev/null +++ b/clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl @@ -0,0 +1,32 @@ +// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-cir -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=CIR + +// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-llvm -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM + +// RUN: %clang_cc1 %s -cl-std=CL2.0 -emit-llvm -triple spirv64-unknown-unknown -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll --check-prefix=OG-LLVM + +typedef __attribute__(( ext_vector_type(2) )) int int2; + +// CIR: %[[LHS:.*]] = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector +// CIR: %[[WIDTH:.*]] = cir.const #cir.const_vector<[#cir.int<31> : !s32i, #cir.int<31> : !s32i]> : !cir.vector +// CIR: %[[MASK:.*]] = cir.binop(and, %[[LHS]], %[[WIDTH]]) : !cir.vector +// CIR: cir.shift(right, %{{.*}} : !cir.vector, %[[MASK]] : !cir.vector) -> !cir.vector +// LLVM: ashr <2 x i32> %{{.*}}, splat (i32 3) +// OG-LLVM: ashr <2 x i32> %x, splat (i32 3) +int2 shr(int2 x) +{ + return x >> 3; +} + +// CIR: %[[LHS:.*]] = cir.const #cir.int<5> : !s32i +// CIR: %[[WIDTH:.*]] = cir.const #cir.int<31> : !s32i +// CIR: %[[MASK:.*]] = cir.binop(and, %[[LHS]], %[[WIDTH]]) : !s32i +// CIR: cir.shift(left, %{{.*}} : !s32i, %[[MASK]] : !s32i) -> !s32i +// LLVM: shl i16 %{{.*}}, 5 +// OG-LLVM: shl i16 %x, 5 +short shl(short x) +{ + return x << 5; +} \ No newline at end of file