diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 079e1e2a8ac67..55ade0ae8eeec 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,46 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 xegpu::TensorDescType descType,
                                                 TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
+  assert(srcTy.isStrided() && "Expected strided memref type");
   auto [strides, offset] = srcTy.getStridesAndOffset();
+  bool isStatic = true;
+
+  // Memref is dynamic if any of its shape, offset, or strides is dynamic.
+  if (!srcTy.hasStaticShape())
+    isStatic = false;
+
+  if (!ShapedType::isStatic(offset))
+    isStatic = false;
+
+  for (auto stride : strides) {
+    if (!ShapedType::isStatic(stride)) {
+      isStatic = false;
+      break;
+    }
+  }
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
+  if (isStatic) {
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
+    // In case of a ranked dynamic memref, instead of passing the memref, the
+    // i64 base address (offset folded in), shape and strides have to be
     // explicitly provided.
     auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           meta.getConstifiedMixedSizes(),
-                                           meta.getConstifiedMixedStrides());
+    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, meta.getBaseBuffer());
+    auto offset = meta.getOffset();
+    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
+    auto offsetInBytes = arith::MulIOp::create(
+        rewriter, loc, offset,
+        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
+    auto adjustedBaseAddr = arith::AddIOp::create(
+        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
+    auto adjustedAddrI64 = arith::IndexCastOp::create(
+        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
+    ndDesc = xegpu::CreateNdDescOp::create(
+        rewriter, loc, descType, adjustedAddrI64,
+        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index ae5141db16c09..c77efa03f3483 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,10 +9,17 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
 // CHECK-LABEL: @load_1D_vector(
 // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME: boundary_check = false
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK: return %[[VEC]]
@@ -29,10 +36,16 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-LABEL: @load_2D_vector(
 // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]

@@ -48,9 +61,15 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-LABEL: @load_dynamic_source(
 // CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]

diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 1a10d917623cc..3c11313d05536 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,10 +11,17 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
 // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
 // CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME: boundary_check = false
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>

@@ -31,10 +38,16 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
 // CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

 // -----
@@ -50,9 +63,15 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
 // CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>

 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 8bb272b1fe5fc..b58f9b30ed726 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -48,10 +48,16 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-LABEL: @load_2D_vector(
 // LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[COLLAPSED]]
-// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND-SAME: : memref<f32> -> index
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// LOAD-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME: boundary_check = false
 // LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND: return %[[VEC]]
@@ -147,9 +153,16 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-LABEL: @load_dynamic_source(
 // LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
 // LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// LOAD-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND: return %[[VEC]]

@@ -184,8 +197,15 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-LABEL: @load_dynamic_source2(
 // LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
 // LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [8, 16], strides : [16, 1] :
+// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND: return %[[VEC]] : vector<8x16xf32>

@@ -459,11 +479,15 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8x16xf16> {
 // LOAD-ND-LABEL: @load_from_subview_2D(
 // LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 2 : index
 // LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]]
-// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
-// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256, 256], strides : [4096, 1] :
+// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
 // LOAD-ND: return %[[VEC]]

diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 43a1a7206e2cc..66da64225678e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -15,10 +15,17 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>,
 // STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME: : memref<f32> -> index
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// STORE-ND-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME: boundary_check = false
 // STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>

@@ -50,10 +57,16 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME: : memref<f32> -> index
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// STORE-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME: boundary_check = false
 // STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

@@ -86,9 +99,15 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
 // STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// STORE-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>

 // STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -293,12 +312,16 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>,
 // STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
 // STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 2 : index
 // STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// STORE-ND-SAME: boundary_check = false
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256], strides : [1] : i64 ->
+// STORE-ND-SAME: !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
 // STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>

 // STORE-SCATTER-LABEL: @store_to_subview(
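
For reviewers, a minimal sketch of the IR this lowering is expected to produce for a dynamic source (not part of the patch; the @example function, the %src/%i/%j names, and the memref<?x?xf32> / vector<8x16xf32> shapes are illustrative assumptions pieced together from the CHECK patterns above, not verbatim pass output):

  func.func @example(%src: memref<?x?xf32>, %i: index, %j: index) -> vector<8x16xf32> {
    // Recover base buffer, offset, sizes, and strides of the dynamic memref.
    %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
        : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
    // Turn the aligned base pointer into an index and fold in the element
    // offset, scaled by the element size in bytes (4 for f32).
    %ptr = memref.extract_aligned_pointer_as_index %base : memref<f32> -> index
    %c4 = arith.constant 4 : index
    %byte_off = arith.muli %off, %c4 : index
    %addr = arith.addi %ptr, %byte_off : index
    %addr_i64 = arith.index_cast %addr : index to i64
    // The descriptor is built from the raw i64 address plus explicit shape and
    // strides instead of the memref itself.
    %desc = xegpu.create_nd_tdesc %addr_i64, shape : [%sizes#0, %sizes#1],
        strides : [%strides#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
    %v = xegpu.load_nd %desc[%i, %j] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    return %v : vector<8x16xf32>
  }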