diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 11a91958d7484..c4d123e0f539c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4255,6 +4255,7 @@ def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)"> def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">; def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">; def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">; +def SPIRV_IsSamplerType : CPred<"::llvm::isa<::mlir::spirv::SamplerType>($_self)">; def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">; def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">; @@ -4298,6 +4299,8 @@ def SPIRV_AnyStruct : DialectType; def SPIRV_AnySampledImage : DialectType; +def SPIRV_AnySampler : DialectType; def SPIRV_AnyTensorArm : DialectType; @@ -4311,7 +4314,7 @@ def SPIRV_Type : AnyTypeOf<[ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector, SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage, - SPIRV_AnyImage, SPIRV_AnyTensorArm + SPIRV_AnySampler, SPIRV_AnyImage, SPIRV_AnyTensorArm ]>; def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; @@ -4412,6 +4415,7 @@ def SPIRV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 2 def SPIRV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; def SPIRV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>; def SPIRV_OC_OpTypeImage : I32EnumAttrCase<"OpTypeImage", 25>; +def SPIRV_OC_OpTypeSampler : I32EnumAttrCase<"OpTypeSampler", 26>; def SPIRV_OC_OpTypeSampledImage : I32EnumAttrCase<"OpTypeSampledImage", 27>; def SPIRV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; def SPIRV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>; @@ -4449,6 +4453,7 @@ def SPIRV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeCons def SPIRV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPIRV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; def SPIRV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>; +def SPIRV_OC_OpSampledImage : I32EnumAttrCase<"OpSampledImage", 86>; def SPIRV_OC_OpImageSampleImplicitLod : I32EnumAttrCase<"OpImageSampleImplicitLod", 87>; def SPIRV_OC_OpImageSampleExplicitLod : I32EnumAttrCase<"OpImageSampleExplicitLod", 88>; def SPIRV_OC_OpImageSampleProjDrefImplicitLod : I32EnumAttrCase<"OpImageSampleProjDrefImplicitLod", 93>; @@ -4661,7 +4666,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpMemoryModel, SPIRV_OC_OpEntryPoint, SPIRV_OC_OpExecutionMode, SPIRV_OC_OpCapability, SPIRV_OC_OpTypeVoid, SPIRV_OC_OpTypeBool, SPIRV_OC_OpTypeInt, SPIRV_OC_OpTypeFloat, SPIRV_OC_OpTypeVector, - SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampledImage, + SPIRV_OC_OpTypeMatrix, SPIRV_OC_OpTypeImage, SPIRV_OC_OpTypeSampler, + SPIRV_OC_OpTypeSampledImage, SPIRV_OC_OpTypeArray, SPIRV_OC_OpTypeRuntimeArray, SPIRV_OC_OpTypeStruct, SPIRV_OC_OpTypePointer, SPIRV_OC_OpTypeFunction, SPIRV_OC_OpTypeForwardPointer, SPIRV_OC_OpConstantTrue, SPIRV_OC_OpConstantFalse, SPIRV_OC_OpConstant, @@ -4677,6 +4683,7 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpVectorInsertDynamic, SPIRV_OC_OpVectorShuffle, SPIRV_OC_OpCompositeConstruct, SPIRV_OC_OpCompositeExtract, SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose, + SPIRV_OC_OpSampledImage, SPIRV_OC_OpImageSampleImplicitLod, SPIRV_OC_OpImageSampleExplicitLod, SPIRV_OC_OpImageSampleProjDrefImplicitLod, SPIRV_OC_OpImageFetch, SPIRV_OC_OpImageDrefGather, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td index e23efa57e5e53..e5ff4c5d96b4a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td @@ -65,6 +65,54 @@ def SPIRV_SampledImageTransform : StrFunc<"llvm::cast<::mlir::spirv::SampledImag // ----- +def SPIRV_SampledImageOp : SPIRV_Op<"SampledImage", + [Pure, + TypesMatchWith<"type of 'result' wraps the image type of 'image'", + "result", "image", + "::llvm::cast($_self).getImageType()">, + SPIRV_DimIsNot<"image", ["SubpassData"]>, + SPIRV_SampledOperandIs<"image", ["SamplerUnknown", "NeedSampler"]>]> { + let summary = "Create a sampled image, containing both a sampler and an image."; + + let description = [{ + Result Type must be OpTypeSampledImage whose Image Type is the same as + the type of the Image operand. + + Image must be an object whose type is an OpTypeImage, whose Sampled + operand is 0 or 1. The Dim operand of the underlying OpTypeImage must + not be SubpassData. Additionally, starting with version 1.6, the Dim + operand must not be Buffer. + + Sampler must be an object of a type made by OpTypeSampler. + + + + #### Example: + + ```mlir + %0 = spirv.SampledImage %image, %sampler : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + ``` + }]; + + let arguments = (ins + SPIRV_AnyImage:$image, + SPIRV_AnySampler:$sampler + ); + + let results = (outs + SPIRV_AnySampledImage:$result + ); + + let assemblyFormat = [{ + $image `,` $sampler attr-dict `:` type($image) `,` type($sampler) + `->` type($result) + }]; + + let hasVerifier = 0; +} + +// ----- + def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure, SPIRV_DimIs<"sampled_image", ["Dim2D", "Cube", "Rect"], SPIRV_SampledImageTransform.result>, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 4a0c29d4b5d90..9864f644aa93e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -228,6 +228,16 @@ class SampledImageType Type getImageType() const; }; +// SPIR-V sampler type +class SamplerType : public Type::TypeBase { +public: + using Base::Base; + + static constexpr StringLiteral name = "spirv.sampler"; + + static SamplerType get(MLIRContext *context); +}; + /// SPIR-V struct type. Two kinds of struct types are supported: /// - Literal: a literal struct type is uniqued by its fields (types + offset /// info + decoration info). diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 5782b42dba026..036d48f0fd637 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -861,6 +861,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const { return parseRuntimeArrayType(*this, parser); if (keyword == "sampled_image") return parseSampledImageType(*this, parser); + if (keyword == "sampler") + return SamplerType::get(getContext()); if (keyword == "struct") return parseStructType(*this, parser); if (keyword == "matrix") @@ -907,6 +909,8 @@ static void print(SampledImageType type, DialectAsmPrinter &os) { os << "sampled_image<" << type.getImageType() << ">"; } +static void print(SamplerType type, DialectAsmPrinter &os) { os << "sampler"; } + static void print(StructType type, DialectAsmPrinter &os) { FailureOr cyclicPrint; @@ -1001,8 +1005,8 @@ static void print(TensorArmType type, DialectAsmPrinter &os) { void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case( - [&](auto type) { print(type, os); }) + ImageType, SampledImageType, SamplerType, StructType, MatrixType, + TensorArmType>([&](auto type) { print(type, os); }) .DefaultUnreachable("Unhandled SPIR-V type"); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 331d98c1d9313..c4dd4cea778d7 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -57,6 +57,7 @@ class TypeExtensionVisitor { for (Type elementType : concreteType.getElementTypes()) add(elementType); }) + .Case([](auto) { /* no extensions */ }) .DefaultUnreachable("Unhandled type"); } @@ -107,6 +108,7 @@ class TypeCapabilityVisitor { for (Type elementType : concreteType.getElementTypes()) add(elementType); }) + .Case([](auto) { /* no capabilities */ }) .DefaultUnreachable("Unhandled type"); } @@ -794,6 +796,14 @@ SampledImageType::verifyInvariants(function_ref emitError, return success(); } +//===----------------------------------------------------------------------===// +// SamplerType +//===----------------------------------------------------------------------===// + +SamplerType SamplerType::get(MLIRContext *context) { + return Base::get(context); +} + //===----------------------------------------------------------------------===// // StructType //===----------------------------------------------------------------------===// @@ -1331,5 +1341,6 @@ TensorArmType::verifyInvariants(function_ref emitError, void SPIRVDialect::registerTypes() { addTypes(); + RuntimeArrayType, SampledImageType, SamplerType, StructType, + TensorArmType>(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index cc6302126d64a..0faa5f0f29d7d 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -183,6 +183,7 @@ LogicalResult spirv::Deserializer::processInstruction( case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: case spirv::Opcode::OpTypeImage: + case spirv::Opcode::OpTypeSampler: case spirv::Opcode::OpTypeSampledImage: case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index f98236c5daece..9557b4647958d 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1150,6 +1150,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode, return processFunctionType(operands); case spirv::Opcode::OpTypeImage: return processImageType(operands); + case spirv::Opcode::OpTypeSampler: + return processSamplerType(operands); case spirv::Opcode::OpTypeSampledImage: return processSampledImageType(operands); case spirv::Opcode::OpTypeRuntimeArray: @@ -1634,6 +1636,15 @@ spirv::Deserializer::processSampledImageType(ArrayRef operands) { return success(); } +LogicalResult +spirv::Deserializer::processSamplerType(ArrayRef operands) { + if (operands.size() != 1) + return emitError(unknownLoc, "OpTypeSampler must have no parameters"); + + typeMap[operands[0]] = spirv::SamplerType::get(context); + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 50c935036158c..e0743503acc5b 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -317,6 +317,8 @@ class Deserializer { LogicalResult processSampledImageType(ArrayRef operands); + LogicalResult processSamplerType(ArrayRef operands); + LogicalResult processRuntimeArrayType(ArrayRef operands); LogicalResult processStructType(ArrayRef operands); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index c21cb27b072f1..aaa80470f40e1 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -728,6 +728,11 @@ LogicalResult Serializer::prepareBasicType( return processTypeDecoration(loc, runtimeArrayType, resultID); } + if (isa(type)) { + typeEnum = spirv::Opcode::OpTypeSampler; + return success(); + } + if (auto sampledImageType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeSampledImage; uint32_t imageTypeID = 0; diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir index 6e4126172f670..ef04b949c5219 100644 --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 // ----- func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> { - // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}} + // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V sampler type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}} %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1> return %0: vector<4x2xi1> } diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir index 12b5f2ce62a68..c4f90cfcb9a49 100644 --- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir @@ -468,3 +468,47 @@ func.func @gard_too_many_args(%arg0 : !spirv.sampled_image>, vector<2xf32>, vector<2xf32>, vector<2xf32>, vector<2xf32> -> vector<4xf32> spirv.Return } + +//===----------------------------------------------------------------------===// +// spirv.SampledImage +//===----------------------------------------------------------------------===// + +// ----- + +func.func @sampled_image(%arg0 : !spirv.image, %arg1 : !spirv.sampler) -> () { + // CHECK: spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + spirv.Return +} + +// ----- + +func.func @sampled_image_sampler_unknown(%arg0 : !spirv.image, %arg1 : !spirv.sampler) -> () { + // CHECK: spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + spirv.Return +} + +// ----- + +func.func @sampled_image_error(%arg0 : !spirv.image, %arg1 : !spirv.sampler) -> () { + // expected-error @+1 {{type of 'result' wraps the image type of 'image'}} + %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + spirv.Return +} + +// ----- + +func.func @sampled_image_dim_subpassdata(%arg0 : !spirv.image, %arg1 : !spirv.sampler) -> () { + // expected-error @+1 {{sampled image Dim must not be SubpassData or Buffer, got SubpassData}} + %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + spirv.Return +} + +// ----- + +func.func @sampled_image_sampled_operand(%arg0 : !spirv.image, %arg1 : !spirv.sampler) -> () { + // expected-error @+1 {{the sampled operand of the underlying image must be SamplerUnknown or NeedSampler}} + %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + spirv.Return +} diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index 710673b73cee5..99443a13e0ec3 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -234,6 +234,15 @@ func.func private @image_parameters_nocomma_5(!spirv.image () + +// ----- + //===----------------------------------------------------------------------===// // SampledImageType //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir index 3593d9b0e9b38..b664561ce9b01 100644 --- a/mlir/test/Target/SPIRV/image-ops.mlir +++ b/mlir/test/Target/SPIRV/image-ops.mlir @@ -61,3 +61,13 @@ spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @sampled_image(%arg0 : !spirv.image, %arg1 : !spirv.sampler) "None" { + // CHECK: {{%.*}} = spirv.SampledImage {{%.*}}, {{%.*}} : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + %0 = spirv.SampledImage %arg0, %arg1 : !spirv.image, !spirv.sampler -> !spirv.sampled_image> + spirv.Return + } +} diff --git a/mlir/test/Target/SPIRV/sampled-image.mlir b/mlir/test/Target/SPIRV/sampled-image.mlir index ff068208540f4..4f6f3256acbac 100644 --- a/mlir/test/Target/SPIRV/sampled-image.mlir +++ b/mlir/test/Target/SPIRV/sampled-image.mlir @@ -14,4 +14,7 @@ spirv.module Logical GLSL450 requires #spirv.vce>, UniformConstant> spirv.GlobalVariable @var2 bind(0, 0) : !spirv.ptr>, UniformConstant> + + // CHECK: !spirv.ptr + spirv.GlobalVariable @var3 bind(0, 2) : !spirv.ptr }