Skip to content

Commit

Permalink
[Matrix] Implement C-style explicit type conversions for matrix types.
Browse files Browse the repository at this point in the history
This implements C-style type conversions for matrix types, as specified
in clang/docs/MatrixTypes.rst.

Fixes PR47141.

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D99037
  • Loading branch information
SaurabhJha authored and fhahn committed Apr 10, 2021
1 parent 471ae42 commit 71ab6c9
Show file tree
Hide file tree
Showing 18 changed files with 368 additions and 13 deletions.
3 changes: 3 additions & 0 deletions clang/include/clang/AST/OperationKinds.def
Expand Up @@ -181,6 +181,9 @@ CAST_OPERATION(PointerToBoolean)
/// (void) malloc(2048)
CAST_OPERATION(ToVoid)

/// CK_MatrixCast - A cast between matrix types of the same dimensions.
CAST_OPERATION(MatrixCast)

/// CK_VectorSplat - A conversion from an arithmetic type to a
/// vector of that element type. Fills all elements ("splats") with
/// the source value.
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/Stmt.h
Expand Up @@ -518,7 +518,7 @@ class alignas(void *) Stmt {

unsigned : NumExprBits;

unsigned Kind : 6;
unsigned Kind : 7;
unsigned PartOfExplicitCast : 1; // Only set for ImplicitCastExpr.

/// True if the call expression has some floating-point features.
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Expand Up @@ -8580,6 +8580,12 @@ let CategoryName = "Inline Assembly Issue" in {

let CategoryName = "Semantic Issue" in {

def err_invalid_conversion_between_matrixes : Error<
"conversion between matrix types%diff{ $ and $|}0,1 of different size is not allowed">;

def err_invalid_conversion_between_matrix_and_type : Error<
"conversion between matrix type %0 and incompatible type %1 is not allowed">;

def err_invalid_conversion_between_vectors : Error<
"invalid conversion between vector type%diff{ $ and $|}0,1 of different "
"size">;
Expand Down
9 changes: 9 additions & 0 deletions clang/include/clang/Sema/Sema.h
Expand Up @@ -11660,6 +11660,8 @@ class Sema final {

bool isValidSveBitcast(QualType srcType, QualType destType);

bool areMatrixTypesOfTheSameDimension(QualType srcTy, QualType destTy);

bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
bool isLaxVectorConversion(QualType srcType, QualType destType);

Expand Down Expand Up @@ -11718,6 +11720,13 @@ class Sema final {
ExprResult checkUnknownAnyArg(SourceLocation callLoc,
Expr *result, QualType &paramType);

// CheckMatrixCast - Check type constraints for matrix casts.
// We allow casting between matrixes of the same dimensions i.e. when they
// have the same number of rows and column. Returns true if the cast is
// invalid.
bool CheckMatrixCast(SourceRange R, QualType DestTy, QualType SrcTy,
CastKind &Kind);

// CheckVectorCast - check type constraints for vectors.
// Since vectors are an extension, there are no C standard reference for this.
// We allow casting between vectors and integer datatypes of the same size.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Expand Up @@ -1708,6 +1708,7 @@ bool CastExpr::CastConsistency() const {
case CK_FixedPointCast:
case CK_FixedPointToIntegral:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
assert(!getType()->isBooleanType() && "unheralded conversion to bool");
goto CheckNoBasePath;

Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Expand Up @@ -13183,6 +13183,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_FixedPointToFloating:
case CK_FixedPointCast:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
llvm_unreachable("invalid cast kind for integral value");

case CK_BitCast:
Expand Down Expand Up @@ -13923,6 +13924,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_FixedPointToBoolean:
case CK_FixedPointToIntegral:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_LValueToRValue:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Expand Up @@ -4645,6 +4645,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_FixedPointToBoolean:
case CK_FixedPointToIntegral:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");

case CK_Dependent:
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CodeGen/CGExprAgg.cpp
Expand Up @@ -901,6 +901,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_CopyAndAutoreleaseBlockObject:
case CK_BuiltinFnToFnPtr:
case CK_ZeroToOCLOpaqueType:
case CK_MatrixCast:

case CK_IntToOCLSampler:
case CK_FloatingToFixedPoint:
Expand Down Expand Up @@ -1422,6 +1423,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_PointerToIntegral:
// Language extensions.
case CK_VectorSplat:
case CK_MatrixCast:
case CK_NonAtomicToAtomic:
case CK_AtomicToNonAtomic:
return true;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprComplex.cpp
Expand Up @@ -533,6 +533,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_FixedPointToBoolean:
case CK_FixedPointToIntegral:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Expand Up @@ -1170,6 +1170,7 @@ class ConstExprEmitter :
case CK_FixedPointToIntegral:
case CK_IntegralToFixedPoint:
case CK_ZeroToOCLOpaqueType:
case CK_MatrixCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
Expand Down
46 changes: 38 additions & 8 deletions clang/lib/CodeGen/CGExprScalar.cpp
Expand Up @@ -1198,27 +1198,50 @@ Value *ScalarExprEmitter::EmitScalarCast(Value *Src, QualType SrcType,
QualType DstType, llvm::Type *SrcTy,
llvm::Type *DstTy,
ScalarConversionOpts Opts) {
if (isa<llvm::IntegerType>(SrcTy)) {
bool InputSigned = SrcType->isSignedIntegerOrEnumerationType();
if (SrcType->isBooleanType() && Opts.TreatBooleanAsSigned) {
// The Element types determine the type of cast to perform.
llvm::Type *SrcElementTy;
llvm::Type *DstElementTy;
QualType SrcElementType;
QualType DstElementType;
if (SrcType->isMatrixType() && DstType->isMatrixType()) {
// Allow bitcast between matrixes of the same size.
if (SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits())
return Builder.CreateBitCast(Src, DstTy, "conv");

SrcElementTy = cast<llvm::VectorType>(SrcTy)->getElementType();
DstElementTy = cast<llvm::VectorType>(DstTy)->getElementType();
SrcElementType = SrcType->castAs<MatrixType>()->getElementType();
DstElementType = DstType->castAs<MatrixType>()->getElementType();
} else {
assert(!SrcType->isMatrixType() && !DstType->isMatrixType() &&
"cannot cast between matrix and non-matrix types");
SrcElementTy = SrcTy;
DstElementTy = DstTy;
SrcElementType = SrcType;
DstElementType = DstType;
}

if (isa<llvm::IntegerType>(SrcElementTy)) {
bool InputSigned = SrcElementType->isSignedIntegerOrEnumerationType();
if (SrcElementType->isBooleanType() && Opts.TreatBooleanAsSigned) {
InputSigned = true;
}

if (isa<llvm::IntegerType>(DstTy))
if (isa<llvm::IntegerType>(DstElementTy))
return Builder.CreateIntCast(Src, DstTy, InputSigned, "conv");
if (InputSigned)
return Builder.CreateSIToFP(Src, DstTy, "conv");
return Builder.CreateUIToFP(Src, DstTy, "conv");
}

if (isa<llvm::IntegerType>(DstTy)) {
assert(SrcTy->isFloatingPointTy() && "Unknown real conversion");
if (DstType->isSignedIntegerOrEnumerationType())
if (isa<llvm::IntegerType>(DstElementTy)) {
assert(SrcElementTy->isFloatingPointTy() && "Unknown real conversion");
if (DstElementType->isSignedIntegerOrEnumerationType())
return Builder.CreateFPToSI(Src, DstTy, "conv");
return Builder.CreateFPToUI(Src, DstTy, "conv");
}

if (DstTy->getTypeID() < SrcTy->getTypeID())
if (DstElementTy->getTypeID() < SrcElementTy->getTypeID())
return Builder.CreateFPTrunc(Src, DstTy, "conv");
return Builder.CreateFPExt(Src, DstTy, "conv");
}
Expand Down Expand Up @@ -1350,6 +1373,9 @@ Value *ScalarExprEmitter::EmitScalarConversion(Value *Src, QualType SrcType,
return Builder.CreateVectorSplat(NumElements, Src, "splat");
}

if (SrcType->isMatrixType() && DstType->isMatrixType())
return EmitScalarCast(Src, SrcType, DstType, SrcTy, DstTy, Opts);

if (isa<llvm::VectorType>(SrcTy) || isa<llvm::VectorType>(DstTy)) {
// Allow bitcast from vector to integer/fp of the same size.
unsigned SrcSize = SrcTy->getPrimitiveSizeInBits();
Expand Down Expand Up @@ -2238,6 +2264,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
CGF.EmitIgnoredExpr(E);
return nullptr;
}
case CK_MatrixCast: {
return EmitScalarConversion(Visit(E), E->getType(), DestTy,
CE->getExprLoc());
}
case CK_VectorSplat: {
llvm::Type *DstTy = ConvertType(DestTy);
Value *Elt = Visit(const_cast<Expr*>(E));
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Expand Up @@ -1080,6 +1080,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
case CK_BuiltinFnToFnPtr:
case CK_ZeroToOCLOpaqueType:
case CK_IntToOCLSampler:
case CK_MatrixCast:
return false;

case CK_BooleanToSignedIntegral:
Expand Down
16 changes: 12 additions & 4 deletions clang/lib/Sema/SemaCast.cpp
Expand Up @@ -2859,7 +2859,8 @@ void CastOperation::CheckCStyleCast() {
return;
}

if (!DestType->isScalarType() && !DestType->isVectorType()) {
if (!DestType->isScalarType() && !DestType->isVectorType() &&
!DestType->isMatrixType()) {
const RecordType *DestRecordTy = DestType->getAs<RecordType>();

if (DestRecordTy && Self.Context.hasSameUnqualifiedType(DestType, SrcType)){
Expand Down Expand Up @@ -2910,10 +2911,11 @@ void CastOperation::CheckCStyleCast() {
return;
}

// The type we're casting to is known to be a scalar or vector.
// The type we're casting to is known to be a scalar, a vector, or a matrix.

// Require the operand to be a scalar or vector.
if (!SrcType->isScalarType() && !SrcType->isVectorType()) {
// Require the operand to be a scalar, a vector, or a matrix.
if (!SrcType->isScalarType() && !SrcType->isVectorType() &&
!SrcType->isMatrixType()) {
Self.Diag(SrcExpr.get()->getExprLoc(),
diag::err_typecheck_expect_scalar_operand)
<< SrcType << SrcExpr.get()->getSourceRange();
Expand All @@ -2926,6 +2928,12 @@ void CastOperation::CheckCStyleCast() {
return;
}

if (DestType->getAs<MatrixType>() || SrcType->getAs<MatrixType>()) {
if (Self.CheckMatrixCast(OpRange, DestType, SrcType, Kind))
SrcExpr = ExprError();
return;
}

if (const VectorType *DestVecTy = DestType->getAs<VectorType>()) {
if (DestVecTy->getVectorKind() == VectorType::AltiVecVector &&
(SrcType->isIntegerType() || SrcType->isFloatingType())) {
Expand Down
34 changes: 34 additions & 0 deletions clang/lib/Sema/SemaExpr.cpp
Expand Up @@ -7345,6 +7345,19 @@ bool Sema::isValidSveBitcast(QualType srcTy, QualType destTy) {
ValidScalableConversion(destTy, srcTy);
}

/// Are the two types matrix types and do they have the same dimensions i.e.
/// do they have the same number of rows and the same number of columns?
bool Sema::areMatrixTypesOfTheSameDimension(QualType srcTy, QualType destTy) {
if (!destTy->isMatrixType() || !srcTy->isMatrixType())
return false;

const ConstantMatrixType *matSrcType = srcTy->getAs<ConstantMatrixType>();
const ConstantMatrixType *matDestType = destTy->getAs<ConstantMatrixType>();

return matSrcType->getNumRows() == matDestType->getNumRows() &&
matSrcType->getNumColumns() == matDestType->getNumColumns();
}

/// Are the two types lax-compatible vector types? That is, given
/// that one of them is a vector, do they have equal storage sizes,
/// where the storage size is the number of elements times the element
Expand Down Expand Up @@ -7407,6 +7420,27 @@ bool Sema::isLaxVectorConversion(QualType srcTy, QualType destTy) {
return areLaxCompatibleVectorTypes(srcTy, destTy);
}

bool Sema::CheckMatrixCast(SourceRange R, QualType DestTy, QualType SrcTy,
CastKind &Kind) {
if (SrcTy->isMatrixType() && DestTy->isMatrixType()) {
if (!areMatrixTypesOfTheSameDimension(SrcTy, DestTy)) {
return Diag(R.getBegin(), diag::err_invalid_conversion_between_matrixes)
<< DestTy << SrcTy << R;
}
} else if (SrcTy->isMatrixType()) {
return Diag(R.getBegin(),
diag::err_invalid_conversion_between_matrix_and_type)
<< SrcTy << DestTy << R;
} else if (DestTy->isMatrixType()) {
return Diag(R.getBegin(),
diag::err_invalid_conversion_between_matrix_and_type)
<< DestTy << SrcTy << R;
}

Kind = CK_MatrixCast;
return false;
}

bool Sema::CheckVectorCast(SourceRange R, QualType VectorTy, QualType Ty,
CastKind &Kind) {
assert(VectorTy->isVectorType() && "Not a vector type!");
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
Expand Up @@ -543,6 +543,10 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
state = handleLVectorSplat(state, LCtx, CastE, Bldr, Pred);
continue;
}
case CK_MatrixCast: {
// TODO: Handle MatrixCast here.
continue;
}
}
}
}
Expand Down

0 comments on commit 71ab6c9

Please sign in to comment.