-
Notifications
You must be signed in to change notification settings - Fork 11.6k
/
sparse_matmul_codegen.mlir
151 lines (147 loc) · 11.7 KB
/
sparse_matmul_codegen.mlir
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --sparsification --sparse-tensor-codegen \
// RUN: --canonicalize --cse | FileCheck %s
// Sparse encoding used by every tensor type in this test: the first
// dimension is stored "dense" and the second "compressed", with the identity
// dimension ordering (i, j) -> (i, j).
#CSR = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
dimOrdering = affine_map<(i,j) -> (i,j)>
}>
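//
// Computes C = A x B, where A (4x8), B (8x4) and C (4x4) are all sparse
// matrices stored in CSR (SpMSpM). The expected output below covers both the
// generated insertion helper and the lowered @matmul kernel.
//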
// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_3:.*]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_5:[^ ]+]]: index,
// CHECK-SAME: %[[VAL_6:.*]]: index,
// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
// CHECK: scf.yield %[[VAL_18]] : i1
// CHECK: } else {
// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.yield %[[VAL_8]] : i1
// CHECK: }
// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
// CHECK: scf.yield %[[VAL_3]] : memref<?xindex>
// CHECK: } else {
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
// CHECK: scf.yield %[[VAL_22]] : memref<?xindex>
// CHECK: }
// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index, index
// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64>
// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1>
// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex>
// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
// CHECK: scf.yield %[[VAL_59]] : index
// CHECK: } else {
// CHECK: scf.yield %[[VAL_50]] : index
// CHECK: }
// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
// CHECK: scf.yield %[[VAL_60:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
// CHECK: scf.if %[[VAL_81]] {
// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: }
// CHECK: scf.yield %[[VAL_82]] : index
// CHECK: }
// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// Test input: multiplies a 4x8 #CSR matrix by an 8x4 #CSR matrix and returns
// the 4x4 #CSR product. The output operand of linalg.matmul is a freshly
// allocated 4x4 #CSR tensor.
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                  %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
  // Destination tensor handed to the matmul through its outs operand.
  %init = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
  %product = linalg.matmul
      ins(%A, %B : tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
      outs(%init : tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
  return %product : tensor<4x4xf64, #CSR>
}