diff --git a/lib/Analysis/PtrAnalysis.cpp b/lib/Analysis/PtrAnalysis.cpp
index 00e454ea..894f5300 100644
--- a/lib/Analysis/PtrAnalysis.cpp
+++ b/lib/Analysis/PtrAnalysis.cpp
@@ -394,6 +394,8 @@ void PtrAnalysis::visitOperandMakeRange(
   auto start = rangeOp.getStart();
   auto end = rangeOp.getEnd();
   auto stride = (end - start + shape[0] - 1) / shape[0];
+  assert(stride == 1 &&
+         "Expect make_range op to always return tensor of stride 1");
 
   state.offsets.push_back(rewriter.getIndexAttr(start));
   state.sizes.push_back(rewriter.getIndexAttr(shape[0]));
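Context for the new assert: `stride` is the ceiling division computed on the line just above it. A standalone sketch (a hypothetical helper, not part of the patch) shows why the old test ranges would now fire the assert while the updated ones satisfy it:

```cpp
#include <cassert>
#include <cstdio>

// Mirrors the stride computation in PtrAnalysis::visitOperandMakeRange:
// ceil((end - start) / size) for a make_range lowered to `size` elements.
int makeRangeStride(int start, int end, int size) {
  return (end - start + size - 1) / size;
}

int main() {
  // Old test input: tt.make_range {start = 1024, end = 2048} : tensor<256xi32>
  // yields stride 4 and would trip the new assert.
  printf("old stride: %d\n", makeRangeStride(1024, 2048, 256)); // 4
  // Updated input: tt.make_range {start = 1024, end = 1280} : tensor<256xi32>
  printf("new stride: %d\n", makeRangeStride(1024, 1280, 256)); // 1
  assert(makeRangeStride(1024, 1280, 256) == 1);
  return 0;
}
```

The assert encodes the expectation stated in its message, that `tt.make_range` always produces a contiguous stride-1 sequence; the test inputs below are updated to respect that.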
diff --git a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir b/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir
index 5a1f004a..b7780324 100644
--- a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir
@@ -9,26 +9,35 @@ module {
     %c3 = arith.constant 3 : index
     %i_c3 = arith.constant 3 : i32
     %0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
-    %1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
-    // source: null, sizes: 256, offsets: 1024, strides: 4
+    %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
+    // source: null, sizes: 256, offsets: 1024, strides: 1
+
     %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-    // source: arg0, sizes: 256, offsets: 1024, strides: 4
+    // source: arg0, sizes: 256, offsets: 1024, strides: 1
+
     // gep operand is another gep's output, which is passed into the loop as a variable, used after update
     %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
       %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
       %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32>
+
      %8 = tt.broadcast %7 : (tensor<256x1xi32>) -> tensor<256x256xi32>
      // sizes: [256, 256], offsets: [0, 0], strides: [1, 0]
+
      %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32>
      %10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32>
+
      %11 = tt.broadcast %10 : (tensor<1x256xi32>) -> tensor<256x256xi32>
      // sizes: [256, 256], offsets: [0, 256], strides: [0, 1]
+
      %12 = arith.addi %8, %11 : tensor<256x256xi32>
      // sizes: [256, 256], offsets: [0, 256], strides: [1, 1]
+
      %13 = tt.expand_dims %ptr {axis = 1 : i32} : (tensor<256x!tt.ptr<bf16>>) -> tensor<256x1x!tt.ptr<bf16>>
      %14 = tt.broadcast %13 : (tensor<256x1x!tt.ptr<bf16>>) -> tensor<256x256x!tt.ptr<bf16>>
+
      %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr<bf16>>, tensor<256x256xi32>
-     // source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [5, 1]
+     // source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [2, 1]
+
      // perform load
      %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256xbf16>
      tt.store %15, %16 : tensor<256x256xbf16>
@@ -44,7 +53,7 @@ module {
 }
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
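The inline comments above track per-dimension (offset, stride) pairs through `expand_dims`, `broadcast`, and `arith.addi`. A compilable toy model of that bookkeeping (simplified; not the pass's actual data structures):

```cpp
#include <cassert>

// Simplified model of the per-dimension (offset, stride) bookkeeping that
// the inline comments above track. Not the pass's real data structures.
struct Dim { long offset; long stride; };

// arith.addi on two address expressions adds offsets and strides
// elementwise (e.g. %12 = arith.addi %8, %11 above).
Dim add(Dim a, Dim b) { return {a.offset + b.offset, a.stride + b.stride}; }

int main() {
  // %8: make_range(0, 256) -> expand_dims(axis=1) -> broadcast:
  //     row dim (offset 0, stride 1), col dim (offset 0, stride 0).
  Dim rows[2] = {{0, 1}, {0, 0}};
  // %11: make_range(256, 512) -> expand_dims(axis=0) -> broadcast:
  //     row dim (offset 0, stride 0), col dim (offset 256, stride 1).
  Dim cols[2] = {{0, 0}, {256, 1}};
  // %12: offsets [0, 256], strides [1, 1], matching the comment above.
  Dim sum[2] = {add(rows[0], cols[0]), add(rows[1], cols[1])};
  assert(sum[0].stride == 1 && sum[1].stride == 1);
  assert(sum[0].offset == 0 && sum[1].offset == 256);
  // %15: tt.addptr then adds the loop-carried pointer's row stride 1,
  // giving the [2, 1] in the updated comment (previously 4 + 1 = 5).
  assert(sum[0].stride + 1 == 2);
  return 0;
}
```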
diff --git a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir b/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir
index c2292bb3..c7019f55 100644
--- a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir
@@ -11,13 +11,13 @@ module {
     %c3 = arith.constant 3 : index
     %c12 = arith.constant 12 : index
     %0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
-    %1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
-    // source: null, sizes: 256, offsets: 1024, strides: 4
+    %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
+    // source: null, sizes: 256, offsets: 1024, strides: 1
     %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-    // source: arg0, sizes: 256, offsets: 1024, strides: 4
+    // source: arg0, sizes: 256, offsets: 1024, strides: 1
     %3 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
     %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-    // source: arg1, sizes: 256, offsets: 1024, strides: 4
+    // source: arg1, sizes: 256, offsets: 1024, strides: 1
     %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr<bf16>>, index, tensor<256x!tt.ptr<bf16>>, index) {
       // perform load
       %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16>
@@ -26,7 +26,7 @@
       %cast3 = arith.index_cast %c3 : index to i32
       %6 = tt.splat %cast3 : (i32) -> tensor<256xi32>
       %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-      // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 4
+      // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 1
       %arg2_iter = arith.addi %arg2, %c3 : index
       %arg3_iter = arith.addi %arg3, %c3 : index
       %arg4_iter = arith.addi %arg4, %c3 : index
@@ -35,7 +35,7 @@
       %cast8 = arith.index_cast %8 : index to i32
       %9 = tt.splat %cast8 : (i32) -> tensor<256xi32>
       %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-      // source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 4
+      // source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 1
       scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr<bf16>>, index, tensor<256x!tt.ptr<bf16>>, index
     }
     tt.return
@@ -43,29 +43,28 @@ }
 }
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) {
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 4 : index
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index
 // CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index
-// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
-// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
+// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
+// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
 // CHECK: %[[VAL_14:.*]]:7 = scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_9]], %[[VAL_19:.*]] = %[[VAL_13]], %[[VAL_20:.*]] = %[[VAL_10]], %[[VAL_21:.*]] = %[[VAL_6]], %[[VAL_22:.*]] = %[[VAL_6]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) {
 // CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16>
 // CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
 // CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16>
 // CHECK: memref.tensor_store %[[VAL_24]], %[[VAL_19]] : memref<256xbf16, strided<[?], offset: ?>>
 // CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index
-// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
 // CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index
 // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_10]] : index
 // CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_20]], %[[VAL_10]] : index
 // CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index
 // CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]] : index
 // CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_22]], %[[VAL_31]] : index
-// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
+// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
 // CHECK: scf.yield %[[VAL_27]], %[[VAL_26]], %[[VAL_28]], %[[VAL_33]], %[[VAL_29]], %[[VAL_25]], %[[VAL_32]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index
 // CHECK: }
 // CHECK: return
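The rewritten loop above carries plain index offsets through `iter_args` instead of pointer tensors. The recurrences in the CHECK lines can be replayed in ordinary C++ (a sketch of the same arithmetic, not generated code):

```cpp
#include <cstdio>

// Replays the loop-carried index math from the CHECK lines above:
// %[[VAL_21]]/%[[VAL_22]] carry the load/store offsets (both start at 1024),
// %[[VAL_16]], %[[VAL_18]], %[[VAL_20]] carry the index iter_args 1, 2, 3.
int main() {
  long ldOff = 1024, stOff = 1024;
  long a = 1, b = 2, c = 3;
  for (int i = 0; i < 12; i += 3) {
    ldOff += 3;          // %[[VAL_25]] = arith.addi %[[VAL_21]], %[[VAL_10]]
    a += 3;              // %[[VAL_27]]
    b += 3;              // %[[VAL_28]]
    c += 3;              // %[[VAL_29]]
    stOff += a + b + c;  // %[[VAL_30]], %[[VAL_31]], %[[VAL_32]]
    printf("iter %d: load offset %ld, store offset %ld\n", i / 3, ldOff, stOff);
  }
  return 0;
}
```

Both `reinterpret_cast` results now reuse the unit-stride constant `%[[VAL_8]]`, which is why the dedicated stride-4 constant `%[[VAL_5]]` disappears.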
diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir
index 1591a506..8aa72b13 100644
--- a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir
@@ -9,17 +9,17 @@ module {
     %c3 = arith.constant 3 : index
     %i_c3 = arith.constant 3 : i32
     %0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
-    %1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
-    // source: null, sizes: 256, offsets: 1024, strides: 4
+    %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
+    // source: null, sizes: 256, offsets: 1024, strides: 1
     %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-    // source: arg0, sizes: 256, offsets: 1024, strides: 4
+    // source: arg0, sizes: 256, offsets: 1024, strides: 1
     // gep operand is another gep's output, which is passed into the loop as a variable, used after update
     %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
       // pointer updates
      %4 = tt.splat %i_c3 : (i32) -> tensor<256xi32>
      // sizes: 256, offsets: 3, strides: 0
      %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-      // source: arg0, sizes: 256, offsets: 1024 + i, strides: 4
+      // source: arg0, sizes: 256, offsets: 1024 + i, strides: 1
      // perform load
      %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16>
      tt.store %ptr_iter, %3 : tensor<256xbf16>
@@ -80,7 +80,7 @@ module {
 }
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index
diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir
index ef4d19fa..fd185fde 100644
--- a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir
@@ -9,10 +9,10 @@ module {
     %c3 = arith.constant 3 : index
     %i_c3 = arith.constant 3 : i32
     %0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
-    %1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
-    // source: null, sizes: 256, offsets: 1024, strides: 4
+    %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
+    // source: null, sizes: 256, offsets: 1024, strides: 1
     %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-    // source: arg0, sizes: 256, offsets: 1024, strides: 4
+    // source: arg0, sizes: 256, offsets: 1024, strides: 1
     // Example 2, gep operand is another gep's output, which is passed into the loop as a variable, used before update
     %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
       // perform load
@@ -36,7 +36,7 @@ module {
 }
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index
diff --git a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir
index eb97907a..47fdf27c 100644
--- a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir
@@ -13,17 +13,17 @@ module {
     //%3: splat(%0) + range(0, 1024)
     //%3: offset = %0, size = 1024, stride = 1
     // vector and scalar are both constant
-    %4 = tt.make_range {end = 4096 : i32, start = 2048 : i32}:tensor<1024xi32>
+    %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32>
     %c10 = arith.constant 10 : i32
     %5 = tt.splat %c10 : (i32) -> tensor<1024xi32>
     %6 = arith.muli %5, %4 : tensor<1024xi32>
     //%6: splat(%c10)*range(2048, 3072);
-    //%6: offset = %c10*2048, size = 1024, stride = %c10*2
+    //%6: offset = %c10*2048, size = 1024, stride = %c10*1
     %7 = arith.addi %3, %6 : tensor<1024xi32>
-    //%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*2+1
+    //%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*1+1
     %8 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
     %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
-    //source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*2+1
+    //source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*1+1
     %10 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
     %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
     //source=%arg1, offset = pid0, size = 1024, stride = 1
@@ -34,14 +34,15 @@ module {
 }
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) {
+// CHECK: %[[VAL_6:.*]] = arith.constant 11 : index
 // CHECK: %[[VAL_7:.*]] = arith.constant 20480 : index
 // CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
 // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : index
-// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}21] : memref<*xbf16> to memref<1024xbf16, strided<[21], offset: ?>>
+// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}%[[VAL_6]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>>
 // CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
 // CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>>
 // CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1024xbf16>
-// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[21], offset: ?>> to memref<1024xbf16>
+// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16>
 // CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16>
 // CHECK: memref.tensor_store %[[VAL_14]], %[[VAL_12]] : memref<1024xbf16, strided<[1], offset: ?>>
 // CHECK: return
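The new CHECK constants follow directly from the comment arithmetic; a throwaway check of the folded values (illustrative only):

```cpp
#include <cassert>

// Folded constants from the comments above: offset = %c10 * 2048 (plus the
// runtime pid term) and stride = %c10 * 1 + 1.
int main() {
  const long c10 = 10;
  assert(c10 * 2048 == 20480); // %[[VAL_7]]
  assert(c10 * 1 + 1 == 11);   // the new %[[VAL_6]] (was 10 * 2 + 1 = 21)
  return 0;
}
```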
diff --git a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir
index 99812799..23b1f23c 100644
--- a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir
@@ -13,16 +13,16 @@ module {
     //%3: splat(%0) + range(0, 1024)
     //%3: offset = %0, size = 1024, stride = 1
     // vector is constant, scalar is value
-    %4 = tt.make_range {end = 4096 : i32, start = 2048 : i32}:tensor<1024xi32>
+    %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32>
     %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32>
     %6 = arith.muli %5, %4 : tensor<1024xi32>
-    //%6: splat(%arg2)*range(2048, 4096);
-    //%6: offset = %arg2*2048, size = 1024, stride = %arg2*2
+    //%6: splat(%arg2)*range(2048, 3072);
+    //%6: offset = %arg2*2048, size = 1024, stride = %arg2*1
     %7 = arith.addi %3, %6 : tensor<1024xi32>
-    //%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*2+1
+    //%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*1+1
     %8 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
     %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
-    //source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*2+1
+    //source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*1+1
     %10 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
     %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
     //source=%arg1: offset = pid0, size = 1024, stride = 1
@@ -34,14 +34,12 @@
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) {
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2048 : index
 // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[ARG_6]] : i32 to index
 // CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
 // CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index
-// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : index
 // CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_6]] : index
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_10]], %[[VAL_6]] : index
 // CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [1024], strides: {{\[}}%[[VAL_14]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>>
 // CHECK: %[[VAL_16:.*]] = arith.index_cast %[[ARG_6]] : i32 to index
 // CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_16]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>>
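With a unit make_range stride, the dynamic stride expression simplifies from `%arg2 * 2 + 1` to `%arg2 + 1`, which is why one `arith.muli` disappears from the CHECK lines. A minimal sketch of the before/after arithmetic (illustrative, not the pass's code):

```cpp
#include <cassert>

// Stride of the strided access is now %arg2 * 1 + 1, so the muli that
// produced %arg2 * 2 (the dropped %[[VAL_12]] line) folds away.
long strideBefore(long arg2) { return arg2 * 2 + 1; } // old: muli + addi
long strideAfter(long arg2) { return arg2 + 1; }      // new: addi only

int main() {
  assert(strideBefore(5) == 11); // what the old CHECK lines computed
  assert(strideAfter(5) == 6);   // one runtime op instead of two
  return 0;
}
```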
diff --git a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir
index f09e2613..5bd3d2b8 100644
--- a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir
@@ -6,24 +6,24 @@ module {
     %arg1 : !tt.ptr<bf16>
   )
   {
-    %0 = tt.make_range {end = 1024 : i32, start = 512 : i32}:tensor<256xi32>
-    // offset = [512] size = 256, stride = 2
+    %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32>
+    // offset = [512] size = 256, stride = 1
     %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32>
-    // offset = [512,0], size = [256,1], stride = [2,0]
+    // offset = [512,0], size = [256,1], stride = [1,0]
     %2 = tt.broadcast %1 : (tensor<256x1xi32>) -> tensor<256x128xi32>
-    // offset = [512,0], size = [256,128], stride = [2,0]
-    %5 = tt.make_range {end = 1408 : i32, start = 1024 : i32}:tensor<128xi32>
-    // offset = 1024, size = 128, stride = 3
+    // offset = [512,0], size = [256,128], stride = [1,0]
+    %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32>
+    // offset = 1024, size = 128, stride = 1
     %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
-    // offset = [0,1024], size = [1,128], stride = [0,3]
+    // offset = [0,1024], size = [1,128], stride = [0,1]
     %7 = tt.broadcast %6 : (tensor<1x128xi32>) -> tensor<256x128xi32>
-    // offset = [0,1024], size = [256,128], stride = [0,3]
+    // offset = [0,1024], size = [256,128], stride = [0,1]
     %c6 = arith.constant 6 : i32
     %splat6 = tt.splat %c6 : (i32) -> tensor<256x128xi32>
     %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32>
-    // offset = [0,6144], size = [256,128], stride = [0,18]
+    // offset = [0,6144], size = [256,128], stride = [0,6]
     %14 = arith.addi %2, %scale7 : tensor<256x128xi32>
-    // offset = [512,6144], size = [256,128], stride = [2,18]
+    // offset = [512,6144], size = [256,128], stride = [1,6]
     %17 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<256x128x!tt.ptr<bf16>>
     %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr<bf16>>, tensor<256x128xi32>
     %19 = tt.load %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x128xbf16>
@@ -33,10 +33,11 @@ module {
 }
 // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) {
-// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [2, 18] : memref<*xbf16> to memref<256x128xbf16, strided<[2, 18], offset: 6656>>
+// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: {{\[}}1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>>
 // CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<256x128xbf16>
-// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[2, 18], offset: 6656>> to memref<256x128xbf16>
+// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16>
 // CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<256x128xbf16>
-// CHECK: memref.tensor_store %[[VAL_9]], %[[VAL_7]] : memref<256x128xbf16, strided<[2, 18], offset: 6656>>
+// CHECK: memref.tensor_store %[[VAL_9]], %[[VAL_7]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>>
 // CHECK: return
 // CHECK: }
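The folded offset and strides in the `reinterpret_cast` above can be verified by hand (illustrative snippet):

```cpp
#include <cassert>

// Hand-checks the folded reinterpret_cast parameters above. The row range
// contributes offset 512 with unit stride; the column range contributes
// offset 1024 scaled by %c6 with stride 1 * 6.
int main() {
  const long rowOffset = 512;      // make_range(512, 768)
  const long colOffset = 1024 * 6; // make_range(1024, 1152) * splat(6)
  assert(rowOffset + colOffset == 6656);
  const long strides[2] = {1, 1 * 6}; // strides {{\[}}1, %[[VAL_6]]]
  assert(strides[0] == 1 && strides[1] == 6);
  return 0;
}
```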
diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir
index 19d030ef..c586fb36 100644
--- a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir
@@ -17,16 +17,16 @@ module {
     // offset = [0, 0], size = [1, 1024], strides = [0, 1]
     %8 = tt.broadcast %7 : (tensor<1x1024xi32>) -> tensor<1024x1024xi32>
     // offset = [0, 0], size = [1024, 1024], strides = [0, 1]
-    %9 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<1024xi32>
-    // offset = 0, size = 1024, strides = 2
+    %9 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+    // offset = 0, size = 1024, strides = 1
     %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<1024xi32>) -> tensor<1024x1xi32>
-    // offset = [0, 0], size = [1024, 1], strides = [2, 0]
+    // offset = [0, 0], size = [1024, 1], strides = [1, 0]
     %11 = tt.broadcast %10 : (tensor<1024x1xi32>) -> tensor<1024x1024xi32>
-    // offset = [0, 0], size = [1024, 1024], strides = [2, 0]
+    // offset = [0, 0], size = [1024, 1024], strides = [1, 0]
     %12 = arith.addi %8, %11 : tensor<1024x1024xi32>
-    // offset = [0, 0], size = [1024, 1024], strides = [2, 1]
+    // offset = [0, 0], size = [1024, 1024], strides = [1, 1]
     %13 = tt.addptr %5, %12 : tensor<1024x1024x!tt.ptr<f32>>, tensor<1024x1024xi32>
-    // source = arg1, offset = [pid * %arg2, 0], size = [1024, 1024], strides = [2, 1]
+    // source = arg1, offset = [pid * %arg2, 0], size = [1024, 1024], strides = [1, 1]
     %14 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024x1024xf32>
     %17 = math.exp %14 : tensor<1024x1024xf32>
     %18 = arith.muli %0, %arg3 : i32
@@ -39,7 +39,7 @@ module {
     %22 = tt.broadcast %21 : (tensor<1024x1x!tt.ptr<f32>>) -> tensor<1024x1024x!tt.ptr<f32>>
     // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [0, 0]
     %23 = tt.addptr %22, %12 : tensor<1024x1024x!tt.ptr<f32>>, tensor<1024x1024xi32>
-    // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [2, 1]
+    // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [1, 1]
     tt.store %23, %17 : tensor<1024x1024xf32>
     tt.return
   }
@@ -48,9 +48,9 @@ // CHECK-LABEL: func.func @kernel(
 // CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) {
 // CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32
 // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index
-// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024, 1024], strides: [2, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[2, 1], offset: ?>>
+// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>>
 // CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024x1024xf32>
-// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024x1024xf32, strided<[2, 1], offset: ?>> to memref<1024x1024xf32>
+// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024x1024xf32, strided<[1, 1], offset: ?>> to memref<1024x1024xf32>
 // CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024x1024xf32>
 // CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_12]] : tensor<1024x1024xf32>) outs(%[[VAL_12]] : tensor<1024x1024xf32>) {
 // CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32):
@@ -59,7 +59,7 @@
 // CHECK: } -> tensor<1024x1024xf32>
 // CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32
 // CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index
-// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [2, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[2, 1], offset: ?>>
-// CHECK: memref.tensor_store %[[VAL_20:.*]], %[[VAL_19]] : memref<1024x1024xf32, strided<[2, 1], offset: ?>>
+// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>>
+// CHECK: memref.tensor_store %[[VAL_20:.*]], %[[VAL_19]] : memref<1024x1024xf32, strided<[1, 1], offset: ?>>
 // CHECK: return
 // CHECK: }
diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir
index af159e08..bc27e898 100644
--- a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir
+++ b/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir
@@ -16,14 +16,14 @@ module {
       %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
       %6 = tt.broadcast %5 : (tensor<1x128xi32>) -> tensor<128x128xi32>
       // offset = [0, 0], size = [128, 128], strides = [0, 1]
-      %7 = tt.make_range {end = 384 : i32, start = 128 : i32} : tensor<128xi32>
+      %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32>
       %8 = tt.expand_dims %7 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
       %9 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x128xi32>
-      // offset = [128, 0], size = [128, 128], strides = [2, 0]
+      // offset = [128, 0], size = [128, 128], strides = [1, 0]
       %10 = arith.addi %6, %9 : tensor<128x128xi32>
-      // offset = [128, 0], size = [128, 128], strides = [2, 1]
+      // offset = [128, 0], size = [128, 128], strides = [1, 1]
       %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
-      // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [2, 1]
+      // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1]
       %12 = tt.load %11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
       %17 = math.exp %12 : tensor<128x128xf32>
       %sum_next = arith.addf %sum_iter, %17 : tensor<128x128xf32>
@@ -36,19 +36,19 @@ module {
     %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
     %6 = tt.broadcast %5 : (tensor<1x128xi32>) -> tensor<128x128xi32>
     // offset = [0, 0], size = [128, 128], strides = [0, 1]
-    %7 = tt.make_range {end = 384 : i32, start = 128 : i32} : tensor<128xi32>
+    %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32>
     %8 = tt.expand_dims %7 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
     %9 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x128xi32>
-    // offset = [128, 0], size = [128, 128], strides = [2, 0]
+    // offset = [128, 0], size = [128, 128], strides = [1, 0]
     %10 = arith.addi %6, %9 : tensor<128x128xi32>
-    // offset = [128, 0], size = [128, 128], strides = [2, 1]
+    // offset = [128, 0], size = [128, 128], strides = [1, 1]
     %18 = arith.muli %0, %arg3 : i32
     %19 = tt.addptr %arg0, %18 : !tt.ptr<f32>, i32
     // source = arg0, offset = %18, size = 1, strides = 0
     %20 = tt.splat %19 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
     // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0]
     %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
-    // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [2, 1]
+    // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1]
     tt.store %21, %sum_out : tensor<128x128xf32>
     tt.return
 }
@@ -66,9 +66,9 @@
 // CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i32 to index
 // CHECK: %[[VAL_17:.*]]:2 = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_14]]) -> (tensor<128x128xf32>, index) {
 // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index
-// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [2, 1] : memref<*xf32> to memref<128x128xf32, strided<[2, 1], offset: ?>>
+// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>>
 // CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<128x128xf32>
-// CHECK: memref.copy %[[VAL_22]], %[[VAL_23]] : memref<128x128xf32, strided<[2, 1], offset: ?>> to memref<128x128xf32>
+// CHECK: memref.copy %[[VAL_22]], %[[VAL_23]] : memref<128x128xf32, strided<[1, 1], offset: ?>> to memref<128x128xf32>
 // CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<128x128xf32>
 // CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_24]] : tensor<128x128xf32>) outs(%[[VAL_24]] : tensor<128x128xf32>) {
 // CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
@@ -86,7 +86,7 @@
 // CHECK: %[[VAL_37:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32
 // CHECK: %[[VAL_38:.*]] = arith.index_cast %[[VAL_37]] : i32 to index
 // CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_8]] : index
-// CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_39]]], sizes: [128, 128], strides: [2, 1] : memref<*xf32> to memref<128x128xf32, strided<[2, 1], offset: ?>>
-// CHECK: memref.tensor_store %[[VAL_41:.*]]#0, %[[VAL_40]] : memref<128x128xf32, strided<[2, 1], offset: ?>>
+// CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_39]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>>
+// CHECK: memref.tensor_store %[[VAL_41:.*]]#0, %[[VAL_40]] : memref<128x128xf32, strided<[1, 1], offset: ?>>
 // CHECK: return
 // CHECK: }
#map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_14]] : tensor<128x128xf32>) outs(%[[VAL_14]] : tensor<128x128xf32>) { // CHECK: ^bb0(%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32): @@ -50,7 +50,7 @@ module { // CHECK: %[[VAL_19:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 // CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : i32 to index // CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [2, 1] : memref<*xf32> to memref<128x128xf32, strided<[2, 1], offset: ?>> -// CHECK: memref.tensor_store %[[VAL_23:.*]], %[[VAL_22]] : memref<128x128xf32, strided<[2, 1], offset: ?>> +// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> +// CHECK: memref.tensor_store %[[VAL_23:.*]], %[[VAL_22]] : memref<128x128xf32, strided<[1, 1], offset: ?>> // CHECK: return // CHECK: }