Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 75 additions & 12 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,53 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
Ops.LHS, Ops.RHS);
}

/// Apply the OpenCL shift-amount rule to the right-hand side of a shift.
///
/// OpenCL 6.3j: shift values are effectively taken modulo the word size of the
/// LHS. For a vector LHS the word size is the width of its integer element
/// type. The reduction is emitted as `rhs & (width - 1)`, which is only
/// equivalent to `rhs % width` when `width` is a power of two; any other width
/// is rejected as NYI instead of silently producing a wrong mask.
///
/// \param CGF     Unused here; kept so the existing call sites stay valid.
/// \param Builder Builder used to create the constant, splat and `and` ops.
/// \param Ops     Shift operands: LHS supplies the bit width, RHS is masked.
/// \returns The masked shift amount, of the same type as `Ops.RHS`.
static mlir::Value maskShiftAmount(CIRGenFunction &CGF,
                                   CIRGenBuilderTy &Builder,
                                   const BinOpInfo &Ops) {
  // The LHS (or its element type, for vectors) determines the word size.
  mlir::Type lhsIntTy = Ops.LHS.getType();
  if (auto lhsVecTy = mlir::dyn_cast<cir::VectorType>(lhsIntTy))
    lhsIntTy = lhsVecTy.getElementType();
  const unsigned bitWidth = mlir::cast<cir::IntType>(lhsIntTy).getWidth();

  // `x & (w - 1) == x % w` holds only for power-of-two `w`.
  if ((bitWidth & (bitWidth - 1)) != 0)
    llvm_unreachable("NYI");

  const unsigned maskValue = bitWidth - 1;
  mlir::Location loc = Ops.RHS.getLoc();

  // Vector RHS: splat the mask to the RHS shape and mask element-wise.
  if (auto rhsVecTy = mlir::dyn_cast<cir::VectorType>(Ops.RHS.getType())) {
    auto rhsElemTy = mlir::cast<cir::IntType>(rhsVecTy.getElementType());
    mlir::Value maskElem = Builder.getConstInt(loc, rhsElemTy, maskValue);
    mlir::Value maskVec =
        cir::VecSplatOp::create(Builder, loc, rhsVecTy, maskElem);
    return cir::BinOp::create(Builder, loc, rhsVecTy, cir::BinOpKind::And,
                              Ops.RHS, maskVec);
  }

  // Scalar RHS: mask with a constant of the RHS integer type.
  auto rhsTy = mlir::cast<cir::IntType>(Ops.RHS.getType());
  mlir::Value mask = Builder.getConstInt(loc, rhsTy, maskValue);
  return cir::BinOp::create(Builder, loc, rhsTy, cir::BinOpKind::And, Ops.RHS,
                            mask);
}

// Emit bitwise left-shift
mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) {
// TODO: This misses out on the sanitizer check below.
if (Ops.isFixedPointOp())
Expand All @@ -1557,18 +1604,27 @@ mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) {
bool SanitizeExponent = CGF.SanOpts.has(SanitizerKind::ShiftExponent);

// OpenCL 6.3j: shift values are effectively % word size of LHS.
if (CGF.getLangOpts().OpenCL)
llvm_unreachable("NYI");
else if ((SanitizeBase || SanitizeExponent) &&
mlir::isa<cir::IntType>(Ops.LHS.getType())) {
llvm_unreachable("NYI");
if (CGF.getLangOpts().OpenCL) {
// Mask RHS according to OpenCL rule
mlir::Value rhsMasked = maskShiftAmount(CGF, Builder, Ops);

return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc),
CGF.convertType(Ops.FullType), Ops.LHS,
rhsMasked, CGF.getBuilder().getUnitAttr());
}

// Sanitizer handling NYI (signed base, unsigned base, or exponent).
if ((SanitizeBase || SanitizeExponent) &&
mlir::isa<cir::IntType>(Ops.LHS.getType()))
llvm_unreachable("NYI");

// Default case: emit simple shift
return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc),
CGF.convertType(Ops.FullType), Ops.LHS, Ops.RHS,
CGF.getBuilder().getUnitAttr());
}

// Emit bitwise right-shift
mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &Ops) {
// TODO: This misses out on the sanitizer check below.
if (Ops.isFixedPointOp())
Expand All @@ -1579,15 +1635,22 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &Ops) {
// promote or truncate the RHS to the same size as the LHS.

// OpenCL 6.3j: shift values are effectively % word size of LHS.
if (CGF.getLangOpts().OpenCL)
llvm_unreachable("NYI");
else if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) &&
mlir::isa<cir::IntType>(Ops.LHS.getType())) {
llvm_unreachable("NYI");
if (CGF.getLangOpts().OpenCL) {
// Mask RHS as per OpenCL shift rule
mlir::Value rhsMasked = maskShiftAmount(CGF, Builder, Ops);

return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc),
CGF.convertType(Ops.FullType), Ops.LHS,
rhsMasked);
}

// Note that we don't need to distinguish unsigned treatment at this
// point since it will be handled later by LLVM lowering.
// Sanitizer handling placeholder for right-shift.
if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) &&
mlir::isa<cir::IntType>(Ops.LHS.getType()))
llvm_unreachable("NYI");

// No special handling needed for signed/unsigned distinction at CIR level.
// LLVM lowering will take care of semantics differences.
return cir::ShiftOp::create(Builder, CGF.getLoc(Ops.Loc),
CGF.convertType(Ops.FullType), Ops.LHS, Ops.RHS);
}
Expand Down
32 changes: 32 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/elemwise-ops.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-cir -triple spirv64-unknown-unknown -o %t.ll
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=CIR

// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-llvm -triple spirv64-unknown-unknown -o %t.ll
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM

// RUN: %clang_cc1 %s -cl-std=CL2.0 -emit-llvm -triple spirv64-unknown-unknown -o %t.ll
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=OG-LLVM

typedef __attribute__(( ext_vector_type(2) )) int int2;

// Vector `>>` by a constant in OpenCL mode. The ClangIR checks expect the
// shift amount (the splat of 3) to be and-ed with a splat of 31 before it is
// used by the shift; both LLVM outputs are expected to show a plain ashr by 3.
// CIR: %[[AMT:.*]] = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector<!s32i x 2>
// CIR: %[[MASK:.*]] = cir.const #cir.const_vector<[#cir.int<31> : !s32i, #cir.int<31> : !s32i]> : !cir.vector<!s32i x 2>
// CIR: %[[MASKED:.*]] = cir.binop(and, %[[AMT]], %[[MASK]]) : !cir.vector<!s32i x 2>
// CIR: cir.shift(right, %{{.*}} : !cir.vector<!s32i x 2>, %[[MASKED]] : !cir.vector<!s32i x 2>) -> !cir.vector<!s32i x 2>
// LLVM: ashr <2 x i32> %{{.*}}, splat (i32 3)
// OG-LLVM: ashr <2 x i32> %x, splat (i32 3)
int2 shr(int2 x)
{
  return x >> 3;
}

// Scalar `<<` by a constant in OpenCL mode. The ClangIR checks expect the
// shift amount (5) to be and-ed with 31 on !s32i operands before it is used by
// the shift; both LLVM outputs are expected to show an i16 shl by 5.
// CIR: %[[AMT:.*]] = cir.const #cir.int<5> : !s32i
// CIR: %[[MASK:.*]] = cir.const #cir.int<31> : !s32i
// CIR: %[[MASKED:.*]] = cir.binop(and, %[[AMT]], %[[MASK]]) : !s32i
// CIR: cir.shift(left, %{{.*}} : !s32i, %[[MASKED]] : !s32i) -> !s32i
// LLVM: shl i16 %{{.*}}, 5
// OG-LLVM: shl i16 %x, 5
short shl(short x)
{
  return x << 5;
}