forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sparse.py
94 lines (77 loc) · 3.18 KB
/
sparse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Sparse operators"""
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from ..util import traverse_inline
from .. import nn
@autotvm.register_topi_compute("sparse_dense.cuda")
def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr):
    """Sparse-dense matrix multiplication of `data` with
    `(weight_data, weight_indices, weight_indptr).T`.

    Parameters
    ----------
    cfg: ConfigEntity
        The autotvm config for this template (unused in the compute
        itself; tuning knobs are consumed by the matching schedule).
    data : tvm.te.Tensor
        2-D with shape [M, K], float32
    weight_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        3-D with shape [num_blocks, bs_r, bs_c] (BSR)
    weight_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
        1-D with shape [num_blocks] (BSR)
    weight_indptr : tvm.te.Tensor
        1-D with shape [N + 1] (CSR) or
        1-D with shape [(N + 1) // bs_r] (BSR)

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [M, N]
    """
    # pylint:disable=unused-argument
    # The CUDA compute is identical to the generic one; only the schedule
    # is CUDA-specific, so delegate straight to the topi.nn implementation.
    csr_or_bsr_weight = (weight_data, weight_indices, weight_indptr)
    return nn.sparse_dense(data, *csr_or_bsr_weight)
@autotvm.register_topi_schedule("sparse_dense.cuda")
def schedule_sparse_dense(cfg, outs):
    """Create a CUDA schedule for sparse_dense.

    Parameters
    ----------
    cfg : ConfigEntity
        The autotvm config; defines/consumes the "tile_c" split knob.
    outs : list of tvm.te.Tensor
        The output tensors of the sparse_dense compute.

    Returns
    -------
    s : tvm.te.Schedule
        The created schedule.
    """
    # pylint:disable=invalid-name
    s = te.create_schedule([x.op for x in outs])
    def _callback(op):
        # Only the BSR matmul stage gets an explicit GPU schedule here;
        # traverse_inline visits every op and inlines elementwise stages.
        if op.tag == "sparse_dense_bsrmm":
            y_bsrmm = op.input_tensors[0]
            assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
            out = s.outputs[0].output(0)
            # c is the inner (block-column) reduction axis of the block matmul.
            (_, c) = s[y_bsrmm].op.reduce_axis
            (m_o, n_o) = s[out].op.axis
            # One CUDA block per (row, block-row) output coordinate.
            s[out].bind(m_o, te.thread_axis("blockIdx.x"))
            s[out].bind(n_o, te.thread_axis("blockIdx.y"))
            s[y_bsrmm].compute_at(s[out], n_o)
            thread_x = te.thread_axis("threadIdx.x")
            # Tunable split of the reduction axis; fall back to a fixed
            # inner factor of 8 when no tuning log is available.
            cfg.define_split("tile_c", c, num_outputs=2)
            if cfg.is_fallback:
                cfg["tile_c"] = SplitEntity([-1, 8])
            _, ci = cfg['tile_c'].apply(s, y_bsrmm, c)
            # rfactor turns the split reduction into a partial-sum stage so
            # the outer reduction axis can be parallelized across threads.
            y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
            tx = s[y_bsrmm].op.reduce_axis[0]
            s[y_bsrmm].bind(tx, thread_x)
            s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
            # Only thread 0 writes the final reduced value; the predicate
            # avoids every thread in the block storing the same result.
            s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
            s[out].set_store_predicate(thread_x.var.equal(0))
    traverse_inline(s, outs[0].op, _callback)
    return s