| 1 | +// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize -canonicalize | FileCheck %s |
| 2 | + |
| 3 | +// CHECK-LABEL: test_vector_insert_2d_idx |
| 4 | +// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32> |
| 5 | +// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> |
| 6 | +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[SRC]] |
| 7 | +// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 64, 65, 66, 67, 16, 17, 18, 19, 20, 21,
| 8 | +// CHECK-SAME: 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, |
| 9 | +// CHECK-SAME: 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<4xf32> |
| 10 | +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> |
| 11 | +// CHECK: return %[[RES]] : vector<2x8x4xf32> |
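|    | +// Inserting at [0, 3] of vector<2x8x4xf32> overwrites linearized elements 12..15 ((0 * 8 + 3) * 4 = 12);
|    | +// in the shuffle mask those positions come from the 4-element source operand (indices 64..67).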
| 12 | +func.func @test_vector_insert_2d_idx(%arg0: vector<2x8x4xf32>, %arg1: vector<4xf32>) -> vector<2x8x4xf32> { |
| 13 | + %0 = vector.insert %arg1, %arg0[0, 3]: vector<4xf32> into vector<2x8x4xf32> |
| 14 | + return %0 : vector<2x8x4xf32> |
| 15 | +} |
| 16 | + |
| 17 | +// ----- |
| 18 | +// CHECK-LABEL: test_vector_transpose |
| 19 | +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32> |
| 20 | +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32> |
| 21 | +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] |
| 22 | +// CHECK: [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<16xf32>, vector<16xf32> |
| 23 | +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> |
| 24 | +// CHECK: return %[[RES]] : vector<8x2xf32> |
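|    | +// The 2x8 -> 8x2 transpose becomes a single shuffle over the linearized vector:
|    | +// result element (i, j) reads source element (j, i), i.e. source index j * 8 + i.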
| 25 | +func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> { |
| 26 | + %0 = vector.transpose %arg, [1, 0] : vector<2x8xf32> to vector<8x2xf32> |
| 27 | + return %0 : vector<8x2xf32> |
| 28 | +} |
| 29 | + |
| 30 | +// ----- |
| 31 | +// CHECK-LABEL: test_vector_transpose_16x16 |
| 32 | +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> |
| 33 | +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> |
| 34 | +// CHECK-COUNT-62: vector.shuffle
| 35 | +func.func @test_vector_transpose_16x16(%arg: vector<16x16xf32>) -> vector<16x16xf32> { |
| 36 | + %0 = vector.transpose %arg, [1, 0] : vector<16x16xf32> to vector<16x16xf32> |
| 37 | + return %0 : vector<16x16xf32> |
| 38 | +} |
| 39 | + |
| 40 | +// ----- |
| 41 | + |
| 42 | +// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16 |
| 43 | +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf16>) |
| 44 | +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index |
| 45 | +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index |
| 46 | +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index |
| 47 | +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index |
| 48 | +// CHECK: %[[LOAD0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 49 | +// CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 50 | +// CHECK: %[[LOAD2:.*]] = vector.load %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 51 | +// CHECK: %[[LOAD3:.*]] = vector.load %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 52 | +// CHECK: vector.store %[[LOAD0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 53 | +// CHECK: vector.store %[[LOAD1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 54 | +// CHECK: vector.store %[[LOAD2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 55 | +// CHECK: vector.store %[[LOAD3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> |
| 56 | +func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) { |
| 57 | + %c0 = arith.constant 0 : index |
| 58 | + %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> |
| 59 | + vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> |
| 60 | + return |
| 61 | +} |
| 62 | + |
| 63 | +// ----- |
| 64 | +// CHECK-LABEL: func.func @test_vector_store_load_4x4x4 |
| 65 | +// CHECK-SAME: (%[[BUF:.*]]: memref<4x4x4xf32>) |
| 66 | +// Constants (order not important) |
| 67 | +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index |
| 68 | +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index |
| 69 | +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index |
| 70 | +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index |
| 71 | +// All 16 row-slice loads of 1D vectors (one vector<4xf32> per innermost row)
| 72 | +// CHECK-COUNT-16: vector.load {{.*}} : memref<4x4x4xf32>, vector<4xf32> |
| 73 | +// No remaining 3D vector load |
| 74 | +// CHECK-NOT: vector.load {{.*}} : memref<4x4x4xf32>, vector<4x4x4xf32> |
| 75 | +// All 16 stores of 1D vectors |
| 76 | +// CHECK-COUNT-16: vector.store {{.*}} : memref<4x4x4xf32>, vector<4xf32> |
| 77 | +// CHECK: return |
| 78 | +func.func @test_vector_store_load_4x4x4(%buffer: memref<4x4x4xf32>) { |
| 79 | + %c0 = arith.constant 0 : index |
| 80 | + %0 = vector.load %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32> |
| 81 | + vector.store %0, %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32> |
| 82 | + return |
| 83 | +} |
| 84 | + |
| 85 | +// ----- |
| 86 | +// CHECK-LABEL: func.func @test_linearize_index |
| 87 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex> |
| 88 | +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> |
| 89 | +// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32> |
| 90 | +// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex> |
| 91 | +// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST2]], %[[CST]] : vector<4xindex> |
| 92 | +// CHECK: %[[INDEX_CAST1:.*]] = arith.index_cast %[[ADDI]] : vector<4xindex> to vector<4xi32> |
| 93 | +// CHECK: %[[MULI:.*]] = arith.muli %[[INDEX_CAST1]], %[[CAST1]] : vector<4xi32> |
| 94 | +// CHECK: %[[INDEX_CAST2:.*]] = arith.index_cast %[[MULI]] : vector<4xi32> to vector<4xindex> |
| 95 | +// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[INDEX_CAST2]] : vector<4xindex> to vector<2x2xindex> |
| 96 | +// CHECK: return %[[RESULT]] : vector<2x2xindex> |
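|    | +// Elementwise arith ops are linearized generically: the 2x2 operands are shape_cast to 4-element
|    | +// vectors, the ops run on those, and the final result is shape_cast back to vector<2x2xindex>.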
| 97 | +func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> { |
| 98 | + %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex> |
| 99 | +  // Arith and math ops are handled in a generic way; check a few of them.
| 100 | + %1 = arith.addi %arg0, %0 : vector<2x2xindex> |
| 101 | + %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32> |
| 102 | + %3 = arith.muli %2, %arg1 : vector<2x2xi32> |
| 103 | + %4 = arith.index_cast %3 : vector<2x2xi32> to vector<2x2xindex> |
| 104 | + return %4 : vector<2x2xindex> |
| 105 | +} |
| 106 | + |
| 107 | +// ----- |
| 108 | +// CHECK-LABEL: func.func @broadcast_stretch_at_start |
| 109 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>) -> vector<3x4xf32> |
| 110 | +// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32> |
| 111 | +// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<1x4xf32> to vector<4xf32> |
| 112 | +// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32> |
| 113 | +// CHECK: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[CAST]] [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32> |
| 114 | +// CHECK: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[CAST]] [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] : vector<12xf32>, vector<4xf32> |
| 115 | +// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<12xf32> to vector<3x4xf32> |
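|    | +// Stretching the leading unit dimension copies the 4-element source row into each 4-element
|    | +// slot of a 12-element vector, one shuffle per output row.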
| 116 | +func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { |
| 117 | + %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> |
| 118 | + return %0 : vector<3x4xf32> |
| 119 | +} |
| 120 | + |
| 121 | +// ----- |
| 122 | +// CHECK-LABEL: func.func @broadcast_stretch_at_end |
| 123 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1xf32>) -> vector<4x3xf32> |
| 124 | +// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32> |
| 125 | +// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<4x1xf32> |
| 126 | +// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[EXTRACT1]] : f32 to vector<3xf32> |
| 127 | +// CHECK: vector.shuffle |
| 128 | +// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<4x1xf32> |
| 129 | +// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[EXTRACT2]] : f32 to vector<3xf32> |
| 130 | +// CHECK: vector.shuffle |
| 131 | +// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG0]][2, 0] : f32 from vector<4x1xf32> |
| 132 | +// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[EXTRACT3]] : f32 to vector<3xf32> |
| 133 | +// CHECK: vector.shuffle |
| 134 | +// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][3, 0] : f32 from vector<4x1xf32> |
| 135 | +// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[EXTRACT4]] : f32 to vector<3xf32> |
| 136 | +// CHECK: vector.shuffle |
| 137 | +// CHECK: vector.shape_cast {{.*}} : vector<12xf32> to vector<4x3xf32> |
| 138 | +func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { |
| 139 | + %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> |
| 140 | + return %0 : vector<4x3xf32> |
| 141 | +} |
| 142 | + |
| 143 | +// ----- |
| 144 | +// CHECK-LABEL: func.func @broadcast_stretch_in_middle |
| 145 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> |
| 146 | +// CHECK: ub.poison : vector<6xf32> |
| 147 | +// CHECK: ub.poison : vector<24xf32> |
| 148 | +// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1x2xf32> to vector<8xf32> |
| 149 | +// CHECK-COUNT-20: vector.shuffle |
| 150 | +// CHECK: vector.shape_cast {{.*}} : vector<24xf32> to vector<4x3x2xf32> |
| 151 | +// CHECK-NOT: vector.broadcast |
| 152 | +func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { |
| 153 | + %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> |
| 154 | + return %0 : vector<4x3x2xf32> |
| 155 | +} |
| 156 | +
|     | +// -----
| 157 | +// CHECK-LABEL: func.func @gather_memref_2d |
| 158 | +// CHECK-SAME: (%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> { |
| 159 | + |
| 160 | +// CHECK: %0 = ub.poison : vector<6xf32> |
| 161 | +// CHECK: %c1 = arith.constant 1 : index |
| 162 | +// CHECK: %c0 = arith.constant 0 : index |
| 163 | +// CHECK: %1 = vector.shape_cast %arg3 : vector<2x3xf32> to vector<6xf32> |
| 164 | + |
| 165 | +// First shuffle + if ladder for row 0 |
| 166 | +// CHECK: %2 = vector.shuffle %1, %1 [0, 1, 2] |
| 167 | +// CHECK: %3 = vector.extract %arg2[0, 0] |
| 168 | +// CHECK: %4 = vector.extract %arg1[0, 0] |
| 169 | +// CHECK: %5 = arith.addi %4, %c1 |
| 170 | +// CHECK: %6 = scf.if %3 -> (vector<3xf32>) { |
| 171 | +// CHECK: %{{.*}} = vector.load %arg0[%c0, %5] : memref<?x?xf32>, vector<1xf32> |
| 172 | +// CHECK: %{{.*}} = vector.extract {{.*}}[0] : f32 |
| 173 | +// CHECK: %{{.*}} = vector.insert {{.*}}, %2 [0] : f32 into vector<3xf32> |
| 174 | +// CHECK: scf.yield {{.*}} : vector<3xf32> |
| 175 | +// CHECK: } else { |
| 176 | +// CHECK: scf.yield %2 : vector<3xf32> |
| 177 | +// CHECK: } |
| 178 | + |
| 179 | +// CHECK: %7 = vector.extract %arg2[0, 1] |
| 180 | +// CHECK: %8 = vector.extract %arg1[0, 1] |
| 181 | +// CHECK: %9 = arith.addi %8, %c1 |
| 182 | +// CHECK: %10 = scf.if %7 -> (vector<3xf32>) |
| 183 | + |
| 184 | +// … (similar checks for the rest of row 0, then row 1) |
| 185 | + |
| 186 | +// CHECK: %15 = vector.shuffle %0, %{{.*}} [6, 7, 8, 3, 4, 5] |
| 187 | +// CHECK: %16 = vector.shuffle %1, %1 [3, 4, 5] |
| 188 | + |
| 189 | +// Row 1 if ladder checks |
| 190 | +// CHECK: %17 = vector.extract %arg2[1, 0] |
| 191 | +// CHECK: %18 = vector.extract %arg1[1, 0] |
| 192 | +// CHECK: %19 = arith.addi %18, %c1 |
| 193 | +// CHECK: %20 = scf.if %17 -> (vector<3xf32>) |
| 194 | + |
| 195 | +// … (similar checks for remaining row 1 inserts) |
| 196 | + |
| 197 | +// Final reshuffle and cast |
| 198 | +// CHECK: %29 = vector.shuffle %15, %{{.*}} [0, 1, 2, 6, 7, 8] |
| 199 | +// CHECK: %30 = vector.shape_cast %29 : vector<6xf32> to vector<2x3xf32> |
| 200 | +// CHECK: return %30 : vector<2x3xf32> |
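|    | +// Each masked lane becomes an scf.if that loads a single element (as vector<1xf32>), extracts the
|    | +// scalar, and inserts it into the per-row vector; lanes with a false mask keep the pass-through value.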
| 201 | +func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> { |
| 202 | + %c0 = arith.constant 0 : index |
| 203 | + %c1 = arith.constant 1 : index |
| 204 | + %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32> |
| 205 | + return %0 : vector<2x3xf32> |
| 206 | +} |
| 207 | + |
| 208 | +// ----- |
| 209 | +// Check for vector linearization interoperability with XeGPU dialect ops. |
| 210 | +// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops. |
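|    | +// The surrounding vector ops (insert_strided_slice, extract_strided_slice, addf) are rewritten as
|    | +// shuffles and elementwise ops on linearized vectors, with shape_casts bridging back to the 2-D
|    | +// types the XeGPU ops expect.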
| 211 | + |
| 212 | +// CHECK: gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel { |
| 213 | +// CHECK: %c0 = arith.constant 0 : index |
| 214 | +// CHECK: %cst = arith.constant dense<0.000000e+00> : vector<64xf16> |
| 215 | +// CHECK: %cst_0 = arith.constant dense<5.000000e+00> : vector<64xf32> |
| 216 | + |
| 217 | +// CHECK: %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] |
| 218 | +// CHECK: %1 = xegpu.load_nd %0 |
| 219 | +// CHECK: %2 = vector.shape_cast %1 : vector<8x16xf16> to vector<128xf16> |
| 220 | +// CHECK: %3 = vector.shuffle %2, %cst {{.*}} : vector<128xf16>, vector<64xf16> |
| 221 | +// CHECK: %4 = vector.shape_cast %3 : vector<128xf16> to vector<8x16xf16> |
| 222 | + |
| 223 | +// CHECK: %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0] |
| 224 | +// CHECK: %6 = xegpu.load_nd %5 |
| 225 | +// CHECK: %7 = vector.shape_cast %6 : vector<16x16xf16> to vector<256xf16> |
| 226 | +// CHECK: %8 = vector.shuffle %7, %cst {{.*}} : vector<256xf16>, vector<64xf16> |
| 227 | +// CHECK: %9 = vector.shape_cast %8 : vector<256xf16> to vector<16x16xf16> |
| 228 | + |
| 229 | +// CHECK: %10 = xegpu.dpas %4, %9 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> |
| 230 | +// CHECK: %11 = vector.shape_cast %10 : vector<8x16xf32> to vector<128xf32> |
| 231 | +// CHECK: %12 = vector.shuffle %11, %11 {{.*}} : vector<128xf32>, vector<128xf32> |
| 232 | +// CHECK: %13 = arith.addf %12, %cst_0 : vector<64xf32> |
| 233 | +// CHECK: %14 = vector.shuffle %11, %13 {{.*}} : vector<128xf32>, vector<64xf32> |
| 234 | +// CHECK: %15 = vector.shape_cast %14 : vector<128xf32> to vector<8x16xf32> |
| 235 | + |
| 236 | +// CHECK: %16 = xegpu.create_nd_tdesc %arg2[%c0, %c0] |
| 237 | +// CHECK: xegpu.store_nd %15, %16 |
| 238 | +// CHECK: gpu.return |
| 239 | + |
| 240 | +gpu.module @test_kernel { |
| 241 | + gpu.func @test_kernel(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %C: memref<8x16xf32>) kernel { |
| 242 | + %c0 = arith.constant 0 : index |
| 243 | + %cst_vec_0 = arith.constant dense<0.000000e+00> : vector<8x8xf16> |
| 244 | + %cst_vec_1 = arith.constant dense<0.000000e+00> : vector<8x8xf16> |
| 245 | + %cst_vec_2 = arith.constant dense<5.000000e+00> : vector<8x8xf32> |
| 246 | + %a_tdesc = xegpu.create_nd_tdesc %A[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> |
| 247 | + %a_val = xegpu.load_nd %a_tdesc : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> -> vector<8x16xf16> |
| 248 | +    %a_val_0 = vector.insert_strided_slice %cst_vec_0, %a_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf16> into vector<8x16xf16>
| 249 | + %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> |
| 250 | + |
| 251 | + %b_val = xegpu.load_nd %b_tdesc : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> -> vector<16x16xf16> |
| 252 | +    %b_val_0 = vector.insert_strided_slice %cst_vec_1, %b_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf16> into vector<16x16xf16>
| 253 | + %c_val = xegpu.dpas %a_val_0, %b_val_0 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> |
| 254 | + %c_val_0 = vector.extract_strided_slice %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> |
| 255 | + %c_addf = arith.addf %c_val_0, %cst_vec_2 : vector<8x8xf32> |
| 256 | + %c_result = vector.insert_strided_slice %c_addf, %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf32> into vector<8x16xf32> |
| 257 | + %c_tdesc = xegpu.create_nd_tdesc %C[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1>> |
| 258 | +    xegpu.store_nd %c_result, %c_tdesc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1>>
| 259 | + gpu.return |
| 260 | + } |
| 261 | +} |
| 262 | + |
| 263 | + |