#map = affine_map<()[s0] -> (s0 * 16)>
#map1 = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map2 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map3 = affine_map<()[s0, s1] -> (s0 * 16 + s1)>
#map4 = affine_map<()[s0, s1] -> (s0 * 64 + s1)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %cst = arith.constant dense<0.000000e+00> : vector<[4]xf16>
    %cst_0 = arith.constant dense<false> : vector<2x[4]xi1>
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst_1 = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_2 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    // Zero-initialise the 32x64 f32 accumulator block.
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    scf.for %arg15 = %c0 to %c32 step %c1 {
      scf.for %arg16 = %c0 to %c64 step %c1 {
        memref.store %cst_1, %alloc[%arg15, %arg16] : memref<32x64xf32>
      }
    }
    // Map the linear program id (%arg12) to a (block_m, block_n) pair using the grouped launch order.
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_3 : memref<32x64xf32> to memref<32x64xf32>
    // K loop: each iteration stages a 32x16 tile of A and a 16x64 tile of B into local buffers before the SME compute loops.
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_3, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
      %39 = arith.addi %arg18, %16 : index
      %40 = arith.remsi %39, %22 : index
      %41 = arith.subi %39, %40 : index
      %42 = arith.addi %40, %c64 : index
      %43 = arith.minsi %42, %22 : index
      %44 = arith.subi %43, %40 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %45 = arith.subi %c64, %44 : index
      %reinterpret_cast_8 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %46 = arith.remsi %arg17, %18 : index
      %47 = arith.addi %20, %46 : index
      %48 = arith.subi %47, %arg17 : index
      %49 = arith.divsi %48, %18 : index
      %reinterpret_cast_9 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %50 = arith.subi %c32, %49 : index
      %reinterpret_cast_10 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %51 = arith.muli %arg15, %c16_i32 : i32
      %52 = arith.subi %arg5, %51 : i32
      %53 = arith.index_cast %52 : i32 to index
      %54 = arith.minsi %53, %c16 : index
      %alloc_11 = memref.alloc() : memref<32x16xf16>
      %55 = arith.cmpi slt, %54, %c16 : index
      // Zero-pad the A staging buffer when fewer than 16 K elements remain.
      scf.if %55 {
        scf.for %arg19 = %c0 to %c32 step %c1 {
          scf.for %arg20 = %c0 to %c16 step %c1 {
            memref.store %cst_2, %alloc_11[%arg19, %arg20] : memref<32x16xf16>
          }
        }
      }
      %56 = arith.minsi %49, %c32 : index
      %57 = arith.subi %c32, %56 : index
      %base_buffer_12, %offset_13, %sizes_14:2, %strides_15:2 = memref.extract_strided_metadata %reinterpret_cast_9 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_16 = memref.reinterpret_cast %base_buffer_12 to offset: [%offset_13], sizes: [%56, %54], strides: [%strides_15#0, %strides_15#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %base_buffer_17, %offset_18, %sizes_19:2, %strides_20:2 = memref.extract_strided_metadata %reinterpret_cast_10 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_21 = memref.reinterpret_cast %base_buffer_17 to offset: [%offset_18], sizes: [%57, %54], strides: [%strides_20#0, %strides_20#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %reinterpret_cast_22 = memref.reinterpret_cast %alloc_11 to offset: [0], sizes: [%56, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %58 = affine.apply #map()[%56]
      %reinterpret_cast_23 = memref.reinterpret_cast %alloc_11 to offset: [%58], sizes: [%57, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %reinterpret_cast_16, %reinterpret_cast_22 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %reinterpret_cast_21, %reinterpret_cast_23 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_24 = memref.alloc() : memref<16x64xf16>
      %59 = arith.cmpi slt, %54, %c16 : index
      // Zero-pad the B staging buffer when fewer than 16 K elements remain.
      scf.if %59 {
        scf.for %arg19 = %c0 to %c16 step %c1 {
          scf.for %arg20 = %c0 to %c64 step %c1 {
            memref.store %cst_2, %alloc_24[%arg19, %arg20] : memref<16x64xf16>
          }
        }
      }
      %60 = arith.minsi %44, %c64 : index
      %61 = arith.subi %c64, %60 : index
      %base_buffer_25, %offset_26, %sizes_27:2, %strides_28:2 = memref.extract_strided_metadata %reinterpret_cast_7 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_29 = memref.reinterpret_cast %base_buffer_25 to offset: [%offset_26], sizes: [%54, %60], strides: [%strides_28#0, %strides_28#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %base_buffer_30, %offset_31, %sizes_32:2, %strides_33:2 = memref.extract_strided_metadata %reinterpret_cast_8 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_34 = memref.reinterpret_cast %base_buffer_30 to offset: [%offset_31], sizes: [%54, %61], strides: [%strides_33#0, %strides_33#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %reinterpret_cast_35 = memref.reinterpret_cast %alloc_24 to offset: [0], sizes: [%54, %60], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %reinterpret_cast_36 = memref.reinterpret_cast %alloc_24 to offset: [%60], sizes: [%54, %61], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %reinterpret_cast_29, %reinterpret_cast_35 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %reinterpret_cast_34, %reinterpret_cast_36 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %vscale = vector.vscale
      %62 = arith.muli %vscale, %c4 : index
      %63 = arith.muli %vscale, %c4 : index
      %alloc_37 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_37 : memref<32x64xf32> to memref<32x64xf32>
      // Tile the 32x64 block by [4]x[4] scalable (SVL-sized) sub-tiles; the innermost loop steps K by 2 for the widening outer product.
      scf.for %arg19 = %c0 to %c32 step %62 {
        scf.for %arg20 = %c0 to %c64 step %63 {
          scf.for %arg21 = %c0 to %c16 step %c2 {
            %alloca = memref.alloca() : memref<vector<2x[4]xf16>>
            %alloca_38 = memref.alloca() : memref<vector<2x[4]xi1>>
            %alloca_39 = memref.alloca() : memref<vector<2x[4]xf16>>
            %alloca_40 = memref.alloca() : memref<vector<2x[4]xi1>>
            %66 = affine.min #map1(%arg19, %62)
            %67 = affine.min #map2(%arg20, %63)
            %68 = affine.min #map1(%arg19, %62)
            %69 = affine.min #map2(%arg20, %63)
            %70 = affine.apply #map3()[%arg19, %arg21]
            %reinterpret_cast_41 = memref.reinterpret_cast %alloc_11 to offset: [%70], sizes: [%66, 2], strides: [16, 1] : memref<32x16xf16> to memref<?x2xf16, strided<[16, 1], offset: ?>>
            %71 = affine.apply #map4()[%arg21, %arg20]
            %reinterpret_cast_42 = memref.reinterpret_cast %alloc_24 to offset: [%71], sizes: [2, %67], strides: [64, 1] : memref<16x64xf16> to memref<2x?xf16, strided<[64, 1], offset: ?>>
            %72 = affine.apply #map4()[%arg19, %arg20]
            %reinterpret_cast_43 = memref.reinterpret_cast %alloc_37 to offset: [%72], sizes: [%68, %69], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %73 = vector.create_mask %66, %c2 : vector<[4]x2xi1>
            %74 = vector.create_mask %67 : vector<[4]xi1>
            %75 = vector.insert %74, %cst_0 [0] : vector<[4]xi1> into vector<2x[4]xi1>
            %76 = vector.insert %74, %75 [1] : vector<[4]xi1> into vector<2x[4]xi1>
            memref.store %76, %alloca_38[] : memref<vector<2x[4]xi1>>
            %77 = vector.type_cast %alloca : memref<vector<2x[4]xf16>> to memref<2xvector<[4]xf16>>
            %78 = vector.type_cast %alloca_38 : memref<vector<2x[4]xi1>> to memref<2xvector<[4]xi1>>
            // Read two masked rows of the B sub-tile into the staging alloca.
            scf.for %arg22 = %c0 to %c2 step %c1 {
              %113 = memref.load %78[%arg22] : memref<2xvector<[4]xi1>>
              %114 = vector.transfer_read %reinterpret_cast_42[%arg22, %c0], %cst_2, %113 {in_bounds = [true]} : memref<2x?xf16, strided<[64, 1], offset: ?>>, vector<[4]xf16>
              memref.store %114, %77[%arg22] : memref<2xvector<[4]xf16>>
            }
            %79 = memref.load %alloca[] : memref<vector<2x[4]xf16>>
            %80 = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
            %vscale_44 = vector.vscale
            %81 = arith.muli %vscale_44, %c4 : index
            %82 = arith.index_cast %66 : index to i64
            %83 = arith.index_cast %81 : index to i64
            %84 = arith.minsi %82, %83 : i64
            %85 = arith.index_cast %84 : i64 to index
            %86 = vector.create_mask %67 : vector<[4]xi1>
            // Load the current accumulator sub-tile into the ZA tile, one slice at a time.
            %87 = scf.for %arg22 = %c0 to %85 step %c1 iter_args(%arg23 = %80) -> (vector<[4]x[4]xf32>) {
              %113 = arm_sme.load_tile_slice %reinterpret_cast_43[%arg22, %c0], %86, %arg23, %arg22 {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
              scf.yield %113 : vector<[4]x[4]xf32>
            }
            %vscale_45 = vector.vscale
            %88 = arith.muli %vscale_45, %c4 : index
            %subview = memref.subview %reinterpret_cast_41[%c0, %c0] [%88, %c2] [%c1, %c1] : memref<?x2xf16, strided<[16, 1], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
            %89 = vector.transpose %73, [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
            %transpose = memref.transpose %subview (d0, d1) -> (d1, d0) : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
            memref.store %89, %alloca_40[] : memref<vector<2x[4]xi1>>
            %90 = vector.type_cast %alloca_39 : memref<vector<2x[4]xf16>> to memref<2xvector<[4]xf16>>
            %91 = vector.type_cast %alloca_40 : memref<vector<2x[4]xi1>> to memref<2xvector<[4]xi1>>
            // Gather the transposed A sub-tile element by element under the mask.
            scf.for %arg22 = %c0 to %c2 step %c1 {
              %113 = memref.load %91[%arg22] : memref<2xvector<[4]xi1>>
              %vscale_48 = vector.vscale
              %114 = arith.muli %vscale_48, %c4 : index
              %115 = scf.for %arg23 = %c0 to %114 step %c1 iter_args(%arg24 = %cst) -> (vector<[4]xf16>) {
                %116 = vector.extractelement %113[%arg23 : index] : vector<[4]xi1>
                %117 = scf.if %116 -> (vector<[4]xf16>) {
                  %118 = memref.load %transpose[%arg22, %arg23] : memref<?x?xf16, strided<[?, ?], offset: ?>>
                  %119 = vector.insertelement %118, %arg24[%arg23 : index] : vector<[4]xf16>
                  scf.yield %119 : vector<[4]xf16>
                } else {
                  scf.yield %arg24 : vector<[4]xf16>
                }
                scf.yield %117 : vector<[4]xf16>
              }
              memref.store %115, %90[%arg22] : memref<2xvector<[4]xf16>>
            }
            %92 = memref.load %alloca_39[] : memref<vector<2x[4]xf16>>
            %93 = vector.extract %92[0] : vector<[4]xf16> from vector<2x[4]xf16>
            %94 = vector.extract %79[0] : vector<[4]xf16> from vector<2x[4]xf16>
            %95 = vector.create_mask %66 : vector<[4]xi1>
            %96 = vector.create_mask %67 : vector<[4]xi1>
            %97 = vector.extract %92[1] : vector<[4]xf16> from vector<2x[4]xf16>
            %98 = vector.extract %79[1] : vector<[4]xf16> from vector<2x[4]xf16>
            %99 = vector.create_mask %66 : vector<[4]xi1>
            %100 = vector.create_mask %67 : vector<[4]xi1>
            // Interleave the two K slices of A and B (and their masks) and issue the widening f16 -> f32 outer product.
            %101 = "llvm.intr.experimental.vector.interleave2"(%93, %97) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
            %102 = "llvm.intr.experimental.vector.interleave2"(%94, %98) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
            %103 = "llvm.intr.experimental.vector.interleave2"(%95, %99) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
            %104 = "llvm.intr.experimental.vector.interleave2"(%96, %100) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
            %105 = arm_sme.fmopa_2way %101, %102 acc(%87) masks(%103, %104) {tile_id = 0 : i32} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
            %vscale_46 = vector.vscale
            %106 = arith.muli %vscale_46, %c4 : index
            %107 = arith.index_cast %66 : index to i64
            %108 = arith.index_cast %106 : index to i64
            %109 = arith.minsi %107, %108 : i64
            %110 = arith.index_cast %109 : i64 to index
            %111 = vector.create_mask %67 : vector<[4]xi1>
            // Write the updated ZA tile back to the accumulator, one slice at a time.
            scf.for %arg22 = %c0 to %110 step %c1 {
              arm_sme.store_tile_slice %105, %arg22, %111, %reinterpret_cast_43[%arg22, %c0] {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
            }
            %112 = affine.apply #map4()[%arg19, %arg20]
            %reinterpret_cast_47 = memref.reinterpret_cast %alloc_37 to offset: [%112], sizes: [%68, %69], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %reinterpret_cast_43, %reinterpret_cast_47 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
          }
        }
      }
      // Accumulate this K step's partial result into the running accumulator.
      scf.for %arg19 = %c0 to %c32 step %c1 {
        scf.for %arg20 = %c0 to %c64 step %c1 {
          %66 = memref.load %alloc_37[%arg19, %arg20] : memref<32x64xf32>
          %67 = memref.load %arg16[%arg19, %arg20] : memref<32x64xf32>
          %68 = arith.addf %66, %67 : f32
          memref.store %68, %alloc_37[%arg19, %arg20] : memref<32x64xf32>
        }
      }
      %64 = arith.addi %arg17, %c16 : index
      %65 = arith.addi %arg18, %26 : index
      scf.yield %alloc_37, %64, %65 : memref<32x64xf32>, index, index
    }
    // Epilogue: compute the C tile view, truncate the f32 accumulator to f16, and copy the edge-masked result back to global memory.
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    scf.for %arg15 = %c0 to %c32 step %c1 {
      scf.for %arg16 = %c0 to %c64 step %c1 {
        %39 = memref.load %27#0[%arg15, %arg16] : memref<32x64xf32>
        %40 = arith.truncf %39 : f32 to f16
        memref.store %40, %alloc_4[%arg15, %arg16] : memref<32x64xf16>
      }
    }
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.subi %32, %14 : index
    %34 = arith.addi %16, %c64 : index
    %35 = arith.minsi %34, %22 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.minsi %33, %c32 : index
    %38 = arith.minsi %36, %c64 : index
    %reinterpret_cast_5 = memref.reinterpret_cast %alloc_4 to offset: [0], sizes: [%37, %38], strides: [64, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %reinterpret_cast : memref<32x64xf16, strided<[?, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
    %reinterpret_cast_6 = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [%37, %38], strides: [%strides#0, 1] : memref<f16> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %reinterpret_cast_5, %reinterpret_cast_6 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}