Fix make_range issue (#63)
Nightlies are failing because our lit tests assume tt.make_range can produce tensors of arbitrary strides; this is no longer true after upstream Triton updated the op's verifier. I updated these tests and added an assert that we are always dealing with stride 1.

Fixes #53 #60
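For context, a minimal illustration of the new constraint (inferred from the test updates below: the verifier appears to require end - start to equal the tensor's element count, i.e. stride 1):

%old = tt.make_range {end = 2048 : i32, start = 1024 : i32} : tensor<256xi32> // 1024 values spread over 256 elements => stride 4, now rejected
%new = tt.make_range {end = 1280 : i32, start = 1024 : i32} : tensor<256xi32> // 256 values over 256 elements => stride 1, accepted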
nhat-nguyen committed Nov 27, 2023
1 parent de797bb commit f90b031
Showing 11 changed files with 94 additions and 84 deletions.
2 changes: 2 additions & 0 deletions lib/Analysis/PtrAnalysis.cpp
@@ -394,6 +394,8 @@ void PtrAnalysis::visitOperandMakeRange(
auto start = rangeOp.getStart();
auto end = rangeOp.getEnd();
auto stride = (end - start + shape[0] - 1) / shape[0];
assert(stride == 1 &&
"Expect make_range op to always return tensor of stride 1");

state.offsets.push_back(rewriter.getIndexAttr(start));
state.sizes.push_back(rewriter.getIndexAttr(shape[0]));
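As a sanity check on the stride computation above: with start = 1024, end = 1280 and shape[0] = 256, the ceiling division gives (1280 - 1024 + 255) / 256 = 1, so the new assert holds; the old test input end = 2048 would give (2048 - 1024 + 255) / 256 = 4 and trip it.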
19 changes: 14 additions & 5 deletions test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir
@@ -9,26 +9,35 @@ module {
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
%0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 4
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 1

%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg0, sizes: 256, offsets: 1024, strides: 4
// source: arg0, sizes: 256, offsets: 1024, strides: 1

// gep operand is another gep's output, which is passed into the loop as variable, used after update
%_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
%6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<256xi32>) -> tensor<256x1xi32>

%8 = tt.broadcast %7 : (tensor<256x1xi32>) -> tensor<256x256xi32>
// sizes: [256, 256], offsets: [0, 0], strides: [1, 0]

%9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32>
%10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<256xi32>) -> tensor<1x256xi32>

%11 = tt.broadcast %10 : (tensor<1x256xi32>) -> tensor<256x256xi32>
// sizes: [256, 256], offsets: [0, 256], strides: [0, 1]

%12 = arith.addi %8, %11 : tensor<256x256xi32>
// sizes: [256, 256], offsets: [0, 256], strides: [1, 1]

%13 = tt.expand_dims %ptr {axis = 1 : i32} : (tensor<256x!tt.ptr<bf16>>) -> tensor<256x1x!tt.ptr<bf16>>
%14 = tt.broadcast %13 : (tensor<256x1x!tt.ptr<bf16>>) -> tensor<256x256x!tt.ptr<bf16>>

%15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr<bf16>>, tensor<256x256xi32>
// source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [5, 1]
// source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [2, 1]

// perform load
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256xbf16>
tt.store %15, %16 : tensor<256x256xbf16>
@@ -44,7 +53,7 @@ module {
}
// CHECK-LABEL: func.func @kernel(
// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
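The CHECK constant dropping from 5 to 2 follows from the stride annotations above: the loop-carried pointer now has stride 1 (previously 4) along the first dimension, and the added offset tensor %12 contributes strides [1, 1], so the combined strides become [1 + 1, 1] = [2, 1] instead of [4 + 1, 1] = [5, 1].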
21 changes: 10 additions & 11 deletions test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir
@@ -11,13 +11,13 @@ module {
%c3 = arith.constant 3 : index
%c12 = arith.constant 12 : index
%0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 4
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 1
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg0, sizes: 256, offsets: 1024, strides: 4
// source: arg0, sizes: 256, offsets: 1024, strides: 1
%3 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
%4 = tt.addptr %3, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg1, sizes: 256, offsets: 1024, strides: 4
// source: arg1, sizes: 256, offsets: 1024, strides: 1
%_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr<bf16>>, index, tensor<256x!tt.ptr<bf16>>, index) {
// perform load
%5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16>
@@ -26,7 +26,7 @@ module {
%cast3 = arith.index_cast %c3 : index to i32
%6 = tt.splat %cast3 : (i32) -> tensor<256xi32>
%ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 4
// source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 1
%arg2_iter = arith.addi %arg2, %c3 : index
%arg3_iter = arith.addi %arg3, %c3 : index
%arg4_iter = arith.addi %arg4, %c3 : index
@@ -35,37 +35,36 @@
%cast8 = arith.index_cast %8 : index to i32
%9 = tt.splat %cast8 : (i32) -> tensor<256xi32>
%ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 4
// source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 1
scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr<bf16>>, index, tensor<256x!tt.ptr<bf16>>, index
}
tt.return
}
}
// CHECK-LABEL: func.func @kernel(
// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) {
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index
// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_14:.*]]:7 = scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_9]], %[[VAL_19:.*]] = %[[VAL_13]], %[[VAL_20:.*]] = %[[VAL_10]], %[[VAL_21:.*]] = %[[VAL_6]], %[[VAL_22:.*]] = %[[VAL_6]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) {
// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16>
// CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16>
// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16>
// CHECK: memref.tensor_store %[[VAL_24]], %[[VAL_19]] : memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index
// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index
// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_10]] : index
// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_20]], %[[VAL_10]] : index
// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index
// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]] : index
// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_22]], %[[VAL_31]] : index
// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_5]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>>
// CHECK: scf.yield %[[VAL_27]], %[[VAL_26]], %[[VAL_28]], %[[VAL_33]], %[[VAL_29]], %[[VAL_25]], %[[VAL_32]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index
// CHECK: }
// CHECK: return
10 changes: 5 additions & 5 deletions test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir
@@ -9,17 +9,17 @@ module {
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
%0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 4
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 1
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg0, sizes: 256, offsets: 1024, strides: 4
// source: arg0, sizes: 256, offsets: 1024, strides: 1
// gep operand is another gep's output, which is passed into the loop as variable, used after update
%_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
// pointer updates
%4 = tt.splat %i_c3 : (i32) -> tensor<256xi32>
// sizes: 256, offsets: 3, strides: 0
%ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg0, sizes: 256, offsets: 1024 + i, strides: 4
// source: arg0, sizes: 256, offsets: 1024 + i, strides: 1
// perform load
%3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256xbf16>
tt.store %ptr_iter, %3 : tensor<256xbf16>
@@ -80,7 +80,7 @@ module {
}
// CHECK-LABEL: func.func @kernel(
// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index
@@ -9,10 +9,10 @@ module {
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
%0 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<256x!tt.ptr<bf16>>
%1 = tt.make_range {end = 2048 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 4
%1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
// source: null, sizes: 256, offsets: 1024, strides: 1
%2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
// source: arg0, sizes: 256, offsets: 1024, strides: 4
// source: arg0, sizes: 256, offsets: 1024, strides: 1
// Example 2, gep operand is another gep's output, which is passed into the loop as variable, used before update
%_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
// perform load
@@ -36,7 +36,7 @@ module {
}
// CHECK-LABEL: func.func @kernel(
// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index
13 changes: 7 additions & 6 deletions test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir
@@ -13,17 +13,17 @@ module {
//%3: splat(%0) + range(0, 1024)
//%3: offset = %0, size = 1024, stride = 1
// vector and scalar are both constant
%4 = tt.make_range {end = 4096 : i32, start = 2048 : i32}:tensor<1024xi32>
%4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32>
%c10 = arith.constant 10 : i32
%5 = tt.splat %c10 : (i32) -> tensor<1024xi32>
%6 = arith.muli %5, %4 : tensor<1024xi32>
//%6: splat(%c10)*range(2048, 4096);
//%6: offset = %c10*2048, size = 1024, stride = %c10*2
//%6: offset = %c10*2048, size = 1024, stride = %c10*1
%7 = arith.addi %3, %6 : tensor<1024xi32>
//%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*2+1
//%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*1+1
%8 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
//source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*2+1
//source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*1+1
%10 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
%11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
//source=%arg1, offset = pid0, size = 1024, stride = 1
@@ -34,14 +34,15 @@ module {
}
// CHECK-LABEL: func.func @kernel(
// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) {
// CHECK: %[[VAL_6:.*]] = arith.constant 11 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 20480 : index
// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : index
// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}21] : memref<*xbf16> to memref<1024xbf16, strided<[21], offset: ?>>
// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}%[[VAL_6]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>>
// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1024xbf16>
// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[21], offset: ?>> to memref<1024xbf16>
// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16>
// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16>
// CHECK: memref.tensor_store %[[VAL_14]], %[[VAL_12]] : memref<1024xbf16, strided<[1], offset: ?>>
// CHECK: return
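The updated CHECK constant 11 matches the stride arithmetic in the comments above: the splatted constant 10 now multiplies a stride-1 range, giving a combined stride of 10 * 1 + 1 = 11 (previously 10 * 2 + 1 = 21); the offset 20480 = 10 * 2048 is unchanged.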
14 changes: 6 additions & 8 deletions test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir
@@ -13,16 +13,16 @@ module {
//%3: splat(%0) + range(0, 1024)
//%3: offset = %0, size = 1024, stride = 1
// vector is constant, scalar is value
%4 = tt.make_range {end = 4096 : i32, start = 2048 : i32}:tensor<1024xi32>
%4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32>
%6 = arith.muli %5, %4 : tensor<1024xi32>
//%6: splat(%arg2)*range(2048, 4096);
//%6: offset = %arg2*2048, size = 1024, stride = %arg2*2
//%6: splat(%arg2)*range(2048, 3072);
//%6: offset = %arg2*2048, size = 1024, stride = %arg2*1
%7 = arith.addi %3, %6 : tensor<1024xi32>
//%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*2+1
//%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*1+1
%8 = tt.splat %arg0 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
//source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*2+1
//source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*1+1
%10 = tt.splat %arg1 : (!tt.ptr<bf16>) -> tensor<1024x!tt.ptr<bf16>>
%11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr<bf16>>, tensor<1024xi32>
//source=%arg1: offset = pid0, size = 1024, stride = 1
@@ -34,14 +34,12 @@ module {
// CHECK-LABEL: func.func @kernel(
// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) {
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2048 : index
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[ARG_6]] : i32 to index
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index
// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_6]] : index
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_10]], %[[VAL_6]] : index
// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [1024], strides: {{\[}}%[[VAL_14]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>>
// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[ARG_6]] : i32 to index
// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_16]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>>
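The same reasoning applies here: the combined stride is now %arg2 * 1 + 1, so the conversion no longer needs the constant 2 and the extra arith.muli, and simply adds 1 to the index-cast of %arg2 to form the dynamic stride.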