152 changes: 0 additions & 152 deletions mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -136,156 +136,4 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLength
//===----------------------------------------------------------------------===//

LogicalResult NVCooperativeMatrixLengthOp::verify() {
if (!isa<CooperativeMatrixNVType>(getCooperativeMatrixType())) {
return emitOpError(
"type attribute must be a '!spirv.NV.coopmatrix' type, found ")
<< getCooperativeMatrixType() << " instead";
}

return success();
}

//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//

ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
Type ptrType;
Type elementType;
if (parser.parseOperandList(operandInfo, 3) ||
parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
return failure();
}
if (parser.resolveOperands(operandInfo,
{ptrType, strideType, columnMajorType},
parser.getNameLoc(), result.operands)) {
return failure();
}

result.addTypes(elementType);
return success();
}

void NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " " << getPointer() << ", " << getStride() << ", "
<< getColumnmajor();
// Print optional memory access attribute.
if (auto memAccess = getMemoryAccess())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
printer << " : " << getPointer().getType() << " as " << getType();
}

static LogicalResult
verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
Type pointeeType = llvm::cast<PointerType>(pointer).getPointeeType();
if (!llvm::isa<ScalarType>(pointeeType) &&
!llvm::isa<VectorType>(pointeeType))
return op->emitError(
"Pointer must point to a scalar or vector type but provided ")
<< pointeeType;
StorageClass storage = llvm::cast<PointerType>(pointer).getStorageClass();
if (storage != StorageClass::Workgroup &&
storage != StorageClass::StorageBuffer &&
storage != StorageClass::PhysicalStorageBuffer)
return op->emitError(
"Pointer storage class must be Workgroup, StorageBuffer or "
"PhysicalStorageBufferEXT but provided ")
<< stringifyStorageClass(storage);
return success();
}

LogicalResult NVCooperativeMatrixLoadOp::verify() {
return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
getResult().getType());
}

//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixStore
//===----------------------------------------------------------------------===//

ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
Type ptrType;
Type elementType;
if (parser.parseOperandList(operandInfo, 4) ||
parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
parser.parseType(ptrType) || parser.parseComma() ||
parser.parseType(elementType)) {
return failure();
}
if (parser.resolveOperands(
operandInfo, {ptrType, elementType, strideType, columnMajorType},
parser.getNameLoc(), result.operands)) {
return failure();
}

return success();
}

void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
<< ", " << getColumnmajor();
// Print optional memory access attribute.
if (auto memAccess = getMemoryAccess())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
}

LogicalResult NVCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
getObject().getType());
}

//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixMulAdd
//===----------------------------------------------------------------------===//

static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) {
if (op.getC().getType() != op.getResult().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = llvm::cast<CooperativeMatrixNVType>(op.getA().getType());
auto typeB = llvm::cast<CooperativeMatrixNVType>(op.getB().getType());
auto typeC = llvm::cast<CooperativeMatrixNVType>(op.getC().getType());
auto typeR = llvm::cast<CooperativeMatrixNVType>(op.getResult().getType());
if (typeA.getRows() != typeR.getRows() ||
typeA.getColumns() != typeB.getRows() ||
typeB.getColumns() != typeR.getColumns())
return op.emitOpError("matrix size must match");
if (typeR.getScope() != typeA.getScope() ||
typeR.getScope() != typeB.getScope() ||
typeR.getScope() != typeC.getScope())
return op.emitOpError("matrix scope must match");
auto elementTypeA = typeA.getElementType();
auto elementTypeB = typeB.getElementType();
if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
llvm::cast<IntegerType>(elementTypeB).getWidth())
return op.emitOpError(
"matrix A and B integer element types must be the same bit width");
} else if (elementTypeA != elementTypeB) {
return op.emitOpError(
"matrix A and B non-integer element types must match");
}
if (typeR.getElementType() != typeC.getElementType())
return op.emitOpError("matrix accumulator element type must match");
return success();
}

LogicalResult NVCooperativeMatrixMulAddOp::verify() {
return verifyCoopMatrixMulAddNV(*this);
}

} // namespace mlir::spirv
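
For reference, the KHR cooperative matrix ops kept by this change cover the same functionality as the removed NV ops. Below is a minimal sketch of the replacement assembly forms, mirroring the KHR tests touched later in this diff; the function and value names are illustrative, not taken from the patch.

```mlir
// The KHR type !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> replaces
// !spirv.NV.coopmatrix<8x16xi32, Subgroup>; the extra parameter is the
// matrix use (MatrixA, MatrixB, or MatrixAcc).
spirv.func @khr_coop_matrix_sketch(%ptr : !spirv.ptr<i32, StorageBuffer>,
                                   %stride : i32) "None" {
  // Per-invocation length of a cooperative matrix type.
  %len = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  // Load with an explicit layout attribute instead of the NV column-major operand.
  %m = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
         !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
  spirv.Return
}
```
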
46 changes: 3 additions & 43 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -360,37 +360,6 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}

// nv-cooperative-matrix-type ::=
// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>`
static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();

SmallVector<int64_t, 2> dims;
SMLoc countLoc = parser.getCurrentLocation();
if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
return Type();

if (dims.size() != 2) {
parser.emitError(countLoc, "expected rows and columns size");
return Type();
}

auto elementTy = parseAndVerifyType(dialect, parser);
if (!elementTy)
return Type();

Scope scope;
if (parser.parseComma() ||
spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();

if (parser.parseGreater())
return Type();
return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
}

// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
// element-type
// `,` layout `,` scope`>`
@@ -810,8 +779,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseArrayType(*this, parser);
if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
if (keyword == "NV.coopmatrix")
return parseCooperativeMatrixNVType(*this, parser);
if (keyword == "jointmatrix")
return parseJointMatrixType(*this, parser);
if (keyword == "image")
@@ -917,12 +884,6 @@ static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
<< type.getUse() << ">";
}

static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", " << stringifyScope(type.getScope());
os << ">";
}

static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", "
@@ -937,10 +898,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {

void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
JointMatrixINTELType, PointerType, RuntimeArrayType, ImageType,
SampledImageType, StructType, MatrixType>(
[&](auto type) { print(type, os); })
.Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType,
RuntimeArrayType, ImageType, SampledImageType, StructType,
MatrixType>([&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}

8 changes: 3 additions & 5 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -374,8 +374,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {

auto coopElementType =
llvm::TypeSwitch<Type, Type>(getType())
.Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
spirv::JointMatrixINTELType>(
.Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
[](auto coopType) { return coopType.getElementType(); })
.Default([](Type) { return nullptr; });

@@ -1677,8 +1676,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
LogicalResult spirv::MatrixTimesScalarOp::verify() {
Type elementType =
llvm::TypeSwitch<Type, Type>(getMatrix().getType())
.Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
spirv::MatrixType>(
.Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
[](auto matrixType) { return matrixType.getElementType(); })
.Default([](Type) { return nullptr; });

@@ -1817,7 +1815,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
return emitError("result type must be a composite type, but provided ")
<< getType();

if (llvm::isa<spirv::CooperativeMatrixNVType>(cType))
if (llvm::isa<spirv::CooperativeMatrixType>(cType))
return emitError("unsupported composite type ") << cType;
if (llvm::isa<spirv::JointMatrixINTELType>(cType))
return emitError("unsupported composite type ") << cType;
93 changes: 14 additions & 79 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -95,9 +95,8 @@ bool CompositeType::classof(Type type) {
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return isValid(vectorType);
return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType,
spirv::MatrixType, spirv::RuntimeArrayType,
spirv::StructType>(type);
spirv::JointMatrixINTELType, spirv::MatrixType,
spirv::RuntimeArrayType, spirv::StructType>(type);
}

bool CompositeType::isValid(VectorType type) {
@@ -108,8 +107,8 @@

Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
.Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
JointMatrixINTELType, RuntimeArrayType, VectorType>(
.Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType,
RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
@@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*this)) {
if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv Cooperative Matrix type");
}
@@ -143,16 +142,16 @@
}

bool CompositeType::hasCompileTimeKnownNumElements() const {
return !llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType,
JointMatrixINTELType, RuntimeArrayType>(*this);
return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
RuntimeArrayType>(*this);
}

void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
.Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
.Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
RuntimeArrayType, StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return llvm::cast<ScalarType>(type.getElementType())
@@ -165,8 +164,8 @@ void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
.Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
.Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
RuntimeArrayType, StructType>(
[&](auto type) { type.getCapabilities(capabilities, storage); })
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
@@ -267,70 +266,6 @@ void CooperativeMatrixType::getCapabilities(
capabilities.push_back(caps);
}

//===----------------------------------------------------------------------===//
// CooperativeMatrixNVType
//===----------------------------------------------------------------------===//

struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage {
using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;

static CooperativeMatrixNVTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
return new (allocator.allocate<CooperativeMatrixNVTypeStorage>())
CooperativeMatrixNVTypeStorage(key);
}

bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, scope, rows, columns);
}

CooperativeMatrixNVTypeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), rows(std::get<2>(key)),
columns(std::get<3>(key)), scope(std::get<1>(key)) {}

Type elementType;
unsigned rows;
unsigned columns;
Scope scope;
};

CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
Scope scope, unsigned rows,
unsigned columns) {
return Base::get(elementType.getContext(), elementType, scope, rows, columns);
}

Type CooperativeMatrixNVType::getElementType() const {
return getImpl()->elementType;
}

Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }

unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }

unsigned CooperativeMatrixNVType::getColumns() const {
return getImpl()->columns;
}

void CooperativeMatrixNVType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
ArrayRef<Extension> ref(exts, std::size(exts));
extensions.push_back(ref);
}

void CooperativeMatrixNVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType())
.getCapabilities(capabilities, storage);
static const Capability caps[] = {Capability::CooperativeMatrixNV};
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
}

//===----------------------------------------------------------------------===//
// JointMatrixType
//===----------------------------------------------------------------------===//
@@ -1312,7 +1247,7 @@ void MatrixType::getCapabilities(
//===----------------------------------------------------------------------===//

void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType, ImageType,
JointMatrixINTELType, MatrixType, PointerType, RuntimeArrayType,
SampledImageType, StructType>();
addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType,
MatrixType, PointerType, RuntimeArrayType, SampledImageType,
StructType>();
}
1 change: 0 additions & 1 deletion mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -165,7 +165,6 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
return processTypeForwardPointer(operands);
33 changes: 0 additions & 33 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -840,8 +840,6 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processArrayType(operands);
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processCooperativeMatrixTypeKHR(operands);
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processCooperativeMatrixTypeNV(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeJointMatrixINTEL:
@@ -1017,37 +1015,6 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
return success();
}

LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
ArrayRef<uint32_t> operands) {
if (operands.size() != 5) {
return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
"type and row x column parameters");
}

Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrixNV references undefined <id> ")
<< operands[1];
}

std::optional<spirv::Scope> scope =
spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
if (!scope) {
return emitError(
unknownLoc,
"OpTypeCooperativeMatrixNV references undefined scope <id> ")
<< operands[2];
}

unsigned rows = getConstantInt(operands[3]).getInt();
unsigned columns = getConstantInt(operands[4]).getInt();

typeMap[operands[0]] =
spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
return success();
}

LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 6) {
20 changes: 0 additions & 20 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -648,26 +648,6 @@ LogicalResult Serializer::prepareBasicType(
return success();
}

if (auto cooperativeMatrixType =
dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
elementTypeID, serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
llvm::append_values(
operands, elementTypeID,
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
getConstantOp(cooperativeMatrixType.getRows()),
getConstantOp(cooperativeMatrixType.getColumns()));
return success();
}

if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
120 changes: 120 additions & 0 deletions mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -845,3 +845,123 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>

// -----

// CHECK-LABEL: func @complex_mul_with_fmf
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
%mul = complex.mul %lhs, %rhs fastmath<nnan,contract> : complex<f32>
return %mul : complex<f32>
}
// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath<nnan,contract> : f32

// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32

// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32

// Handle cases where the "naive" calculation results in NaN values.
// CHECK: %[[REAL_IS_NAN:.*]] = arith.cmpf uno, %[[REAL]], %[[REAL]] : f32
// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG]], %[[IMAG]] : f32
// CHECK: %[[IS_NAN:.*]] = arith.andi %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32

// Case 1. LHS_REAL or LHS_IMAG are infinite.
// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
// CHECK: %[[LHS_IS_INF:.*]] = arith.ori %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
// CHECK: %[[RHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
// CHECK: %[[LHS_REAL1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
// CHECK: %[[LHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL]] : f32
// CHECK: %[[RHS_REAL1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG]] : f32
// CHECK: %[[RHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32

// Case 2. RHS_REAL or RHS_IMAG are infinite.
// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
// CHECK: %[[RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
// CHECK: %[[RHS_IS_INF:.*]] = arith.ori %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
// CHECK: %[[LHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
// CHECK: %[[RHS_REAL2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
// CHECK: %[[RHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL1]] : f32
// CHECK: %[[LHS_REAL2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
// CHECK: %[[LHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
// CHECK: %[[RECALC:.*]] = arith.ori %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1

// Case 3. One of the pairwise products of left hand side with right hand side
// is infinite.
// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
// CHECK: %[[IS_SPECIAL_CASE:.*]] = arith.ori %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
// CHECK: %[[IS_SPECIAL_CASE1:.*]] = arith.ori %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
// CHECK: %[[IS_SPECIAL_CASE2:.*]] = arith.ori %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[NOT_RECALC:.*]] = arith.xori %[[RECALC]], %[[TRUE]] : i1
// CHECK: %[[IS_SPECIAL_CASE3:.*]] = arith.andi %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL2]] : f32
// CHECK: %[[LHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
// CHECK: %[[LHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL2]] : f32
// CHECK: %[[RHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
// CHECK: %[[RHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
// CHECK: %[[RECALC2:.*]] = arith.ori %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
// CHECK: %[[RECALC3:.*]] = arith.andi %[[IS_NAN]], %[[RECALC2]] : i1

// Recalculate real part.
// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_REAL3]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] fastmath<nnan,contract> : f32
// CHECK: %[[NEW_REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[FINAL_REAL:.*]] = arith.select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32

// Recalculate imag part.
// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] fastmath<nnan,contract> : f32
// CHECK: %[[NEW_IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32

// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
@@ -1,4 +1,4 @@
// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
// RUN: mlir-opt --convert-gpu-to-spirv --cse \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s

module attributes {
194 changes: 0 additions & 194 deletions mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir

This file was deleted.

24 changes: 0 additions & 24 deletions mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -146,14 +146,6 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou

// -----

func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
// CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
%0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
@@ -238,14 +230,6 @@ func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, M

// -----

func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
%0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
spirv.Return
}

// -----

func.func @f_convert_vector(%arg0 : f32) -> f32 {
// expected-error @+1 {{expected the different bit widths for operand type and result type, but provided 'f32' and 'f32'}}
%0 = spirv.FConvert %arg0 : f32 to f32
@@ -254,14 +238,6 @@

// -----

func.func @f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>) {
// expected-error @+1 {{incompatible operand and result types}}
%0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
spirv.Return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
39 changes: 0 additions & 39 deletions mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -32,13 +32,6 @@ func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
}

// CHECK-LABEL: func @composite_construct_coopmatrix_nv
func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
}

// -----

func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
@@ -75,22 +68,6 @@ func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32

// -----

func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
}

// -----

func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
%0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
}

// -----

func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> {
// expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32>
@@ -143,14 +120,6 @@ func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {

// -----

func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 {
// CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
%0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0 : f32
}

// -----

func.func @composite_extract_no_ssa_operand() -> () {
// expected-error @+1 {{expected SSA operand}}
%0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
@@ -271,14 +240,6 @@ func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f3

// -----

func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> {
// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
%0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>
}

// -----

func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
// expected-error @+1 {{expected at least one index}}
%0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32>
26 changes: 0 additions & 26 deletions mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -13,14 +13,6 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {

// -----

spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
// expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
%0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.ReturnValue %0 : i32
}

// -----

// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
@@ -118,24 +110,6 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB

// -----

spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{expected '<'}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
}

// -----

spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}

// -----

spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -9,10 +9,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
}

// CHECK-LABEL: @matrix_times_scalar_2
spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
}

// CHECK-LABEL: @matrix_transpose_1
177 changes: 0 additions & 177 deletions mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir

This file was deleted.

4 changes: 2 additions & 2 deletions mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -810,15 +810,15 @@ spirv.module Logical GLSL450 {
}

//===----------------------------------------------------------------------===//
// spirv.SpecConstantComposite (spirv.NV.coopmatrix)
// spirv.SpecConstantComposite (spirv.KHR.coopmatrix)
//===----------------------------------------------------------------------===//

// -----

spirv.module Logical GLSL450 {
spirv.SpecConstant @sc1 = 1.5 : f32
// expected-error @+1 {{unsupported composite type}}
spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device>
spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device, MatrixA>
}

//===----------------------------------------------------------------------===//
19 changes: 0 additions & 19 deletions mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -479,25 +479,6 @@ func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup

// -----

//===----------------------------------------------------------------------===//
// NV.CooperativeMatrix
//===----------------------------------------------------------------------===//

// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()

// -----

// expected-error @+1 {{expected ','}}
func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> ()

// -----

// expected-error @+1 {{expected rows and columns size}}
func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> ()

// -----

//===----------------------------------------------------------------------===//
// Matrix
//===----------------------------------------------------------------------===//
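
The deleted `!spirv.NV.coopmatrix` parsing tests have direct analogues with the remaining KHR type, whose grammar is `rows x columns x element-type , scope , use`. A hypothetical replacement test (names illustrative, not part of this patch) would look like:

```mlir
// CHECK: func private @khr_coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>, !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>)
func.func private @khr_coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                        !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>) -> ()
```
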
8 changes: 4 additions & 4 deletions mlir/test/Target/SPIRV/matrix.mlir
@@ -23,10 +23,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
}

// CHECK-LABEL: @matrix_times_scalar_3
spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
}

// CHECK-LABEL: @matrix_transpose_1
102 changes: 0 additions & 102 deletions mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir

This file was deleted.

98 changes: 34 additions & 64 deletions mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -311,80 +311,50 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
// Collect list of operations that can be tiled and fused.
llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
collectTiledAndFusedOps(rootOp);
auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
outerMostTiledLoop->isAncestor(user);
llvm::SmallDenseMap<Operation *, bool> yielded;
auto isIgnoredUser = [&](Operation *user) {
return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user);
};

// The rest of this method is similar to
// scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that also
// yields replacements for values of the fused producer.

// 1. Tile the consumer.
SmallVector<OpResult> yieldedValuesToOrigValues;
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCFForOp(rewriter, rootOp, options);
if (failed(tilingResult)) {
return rewriter.notifyMatchFailure(rootOp,
"failed to tile base operation");
for (Operation *op : tiledAndFusedOps) {
yielded[op] = llvm::any_of(op->getUsers(), [&](Operation *user) {
return !isIgnoredUser(user);
});
}
yieldedValuesToOrigValues.append(rootOp->result_begin(),
rootOp->result_end());

// 2. Tiling each operation results in generation of slices. The source of
// these slices could be producers that can be fused into the tiled loops by
// computing the slices of these producers in-place. This results in more
// slices created for operands of the "fused producer". This open up more
// opportunities for fusion. Use a worklist to fuse greedily.
auto addCandidateSlices =
[](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
for (Value operand : fusedOp->getOperands())
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
candidates.push_back(sliceOp);
};

std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tilingResult->tiledOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
auto forLoops = llvm::to_vector(llvm::map_range(
tilingResult->loops, [](auto op) { return cast<scf::ForOp>(op); }));
while (!candidates.empty()) {
// Traverse the slices in BFS fashion.
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
candidates.pop_front();

// Materialize the slice of the producer in place.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
if (!fusedProducer)
continue;

// Check if the fused producer has other uses that require the value
// to be yielded from within the tiled loop.
OpResult untiledProducer = fusedProducer->origProducer;
if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
return !isIgnoredUser(user, forLoops.front());
})) {
yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
fusedProducer.value(), forLoops);
yieldedValuesToOrigValues.push_back(untiledProducer);
}
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(options);
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
bool isDestinationOperand) {
Operation *owner = originalProducer.getOwner();
return std::make_tuple(true,
yielded.contains(owner) && yielded[owner]);
};
tileAndFuseOptions.setFusionControlFn(controlFn);

// Add more fusion candidates to the worklist.
if (auto fusedProducerOp =
fusedProducer->tiledAndFusedProducer.getDefiningOp())
addCandidateSlices(fusedProducerOp, candidates);
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, rootOp, tileAndFuseOptions);
if (failed(tileAndFuseResult)) {
return rewriter.notifyMatchFailure(
rootOp, "failed to tile and fuse with op as root");
}

scf::ForOp outermostLoop = forLoops.front();
for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
Value replacement = outermostLoop.getResult(index);
for (auto it : tileAndFuseResult->replacements) {
Value origVal = it.first;
Value replacement = it.second;
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
return !isIgnoredUser(use.getOwner(), outermostLoop);
Operation *user = use.getOwner();
return !isIgnoredUser(user) &&
!tileAndFuseResult->loops.front()->isAncestor(user);
});
}

rewriter.eraseOp(rootOp);
filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back());
for (auto tiledAndFusedOp : tileAndFuseResult->tiledAndFusedOps)
if (tiledAndFusedOp->hasAttr(kTransformMarker))
filter.replaceTransformationFilter(rewriter, tiledAndFusedOp);

return success();
}

5 changes: 3 additions & 2 deletions openmp/libomptarget/include/DeviceImage.h
@@ -26,7 +26,6 @@
class DeviceImageTy {

std::unique_ptr<llvm::object::OffloadBinary> Binary;
llvm::SmallVector<std::unique_ptr<OffloadEntryTy>> OffloadEntries;

__tgt_bin_desc *BinaryDesc;
__tgt_device_image Image;
@@ -37,7 +36,9 @@ class DeviceImageTy {
__tgt_device_image &getExecutableImage() { return Image; }
__tgt_bin_desc &getBinaryDesc() { return *BinaryDesc; }

auto entries() { return llvm::make_pointee_range(OffloadEntries); }
auto entries() {
return llvm::make_range(Image.EntriesBegin, Image.EntriesEnd);
}
};

#endif // OMPTARGET_DEVICE_IMAGE_H
8 changes: 4 additions & 4 deletions openmp/libomptarget/include/OffloadEntry.h
@@ -36,11 +36,11 @@ class OffloadEntryTy {
const char *getNameAsCStr() const { return OffloadEntry.name; }
__tgt_bin_desc *getBinaryDescription() const;

bool isCTor() { return hasFlags(OMP_DECLARE_TARGET_CTOR); }
bool isDTor() { return hasFlags(OMP_DECLARE_TARGET_DTOR); }
bool isLink() { return hasFlags(OMP_DECLARE_TARGET_LINK); }
bool isCTor() const { return hasFlags(OMP_DECLARE_TARGET_CTOR); }
bool isDTor() const { return hasFlags(OMP_DECLARE_TARGET_DTOR); }
bool isLink() const { return hasFlags(OMP_DECLARE_TARGET_LINK); }

bool hasFlags(OpenMPOffloadingDeclareTargetFlags Flags) {
bool hasFlags(OpenMPOffloadingDeclareTargetFlags Flags) const {
return Flags & OffloadEntry.flags;
}
};
4 changes: 2 additions & 2 deletions openmp/libomptarget/include/device.h
@@ -159,7 +159,7 @@ struct DeviceTy {
/// }

/// Register \p Entry as an offload entry that is avalable on this device.
void addOffloadEntry(OffloadEntryTy &Entry);
void addOffloadEntry(const OffloadEntryTy &Entry);

/// Print all offload entries to stderr.
void dumpOffloadEntries();
@@ -170,7 +170,7 @@

/// All offload entries available on this device.
using DeviceOffloadEntriesMapTy =
llvm::DenseMap<llvm::StringRef, OffloadEntryTy *>;
llvm::DenseMap<llvm::StringRef, OffloadEntryTy>;
ProtectedObj<DeviceOffloadEntriesMapTy> DeviceOffloadEntries;

/// Handler to collect and organize host-2-device mapping information.
4 changes: 0 additions & 4 deletions openmp/libomptarget/src/DeviceImage.cpp
@@ -27,10 +27,6 @@ DeviceImageTy::DeviceImageTy(__tgt_bin_desc &BinaryDesc,
__tgt_device_image &TgtDeviceImage)
: BinaryDesc(&BinaryDesc), Image(TgtDeviceImage) {

for (__tgt_offload_entry &Entry :
llvm::make_range(Image.EntriesBegin, Image.EntriesEnd))
OffloadEntries.emplace_back(std::make_unique<OffloadEntryTy>(*this, Entry));

llvm::StringRef ImageStr(
static_cast<char *>(Image.ImageStart),
llvm::omp::target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
4 changes: 2 additions & 2 deletions openmp/libomptarget/src/PluginManager.cpp
@@ -97,8 +97,8 @@ void PluginAdaptorTy::addOffloadEntries(DeviceImageTy &DI) {
toString(DeviceOrErr.takeError()).c_str());

DeviceTy &Device = *DeviceOrErr;
for (OffloadEntryTy &Entry : DI.entries())
Device.addOffloadEntry(Entry);
for (__tgt_offload_entry &Entry : DI.entries())
Device.addOffloadEntry(OffloadEntryTy(DI, Entry));
}
}

15 changes: 7 additions & 8 deletions openmp/libomptarget/src/device.cpp
@@ -291,10 +291,9 @@ int32_t DeviceTy::destroyEvent(void *Event) {
return OFFLOAD_SUCCESS;
}

void DeviceTy::addOffloadEntry(OffloadEntryTy &Entry) {
void DeviceTy::addOffloadEntry(const OffloadEntryTy &Entry) {
std::lock_guard<decltype(PendingGlobalsMtx)> Lock(PendingGlobalsMtx);
DeviceOffloadEntries.getExclusiveAccessor()->insert(
{Entry.getName(), &Entry});
DeviceOffloadEntries.getExclusiveAccessor()->insert({Entry.getName(), Entry});
if (Entry.isGlobal())
return;

@@ -329,14 +328,14 @@ void DeviceTy::dumpOffloadEntries() {
fprintf(stderr, "Device %i offload entries:\n", DeviceID);
for (auto &It : *DeviceOffloadEntries.getExclusiveAccessor()) {
const char *Kind = "kernel";
if (It.second->isCTor())
if (It.second.isCTor())
Kind = "constructor";
else if (It.second->isDTor())
else if (It.second.isDTor())
Kind = "destructor";
else if (It.second->isLink())
else if (It.second.isLink())
Kind = "link";
else if (It.second->isGlobal())
else if (It.second.isGlobal())
Kind = "global var.";
fprintf(stderr, " %11s: %s\n", Kind, It.second->getNameAsCStr());
fprintf(stderr, " %11s: %s\n", Kind, It.second.getNameAsCStr());
}
}