Skip to content

Commit

Permalink
Accelerate LU tile when input has one chunk (#905)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored and wjsi committed Jan 9, 2020
1 parent 28ca58b commit c68ed39
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions mars/tensor/linalg/lu.py
Expand Up @@ -73,8 +73,35 @@ def __call__(self, a):
])
return ExecutableTuple([p, l, u])

@classmethod
def _tile_one_chunk(cls, op):
    """Tile an LU operand without splitting, producing one chunk per output.

    The operand's three outputs are taken in order as P, L and U. A copy
    of the operand (with ``reset_key()`` applied) creates three chunks
    from ``op.input.chunks``; each chunk carries the dtype, shape and
    order of its output tensor and sits at the all-zero index. The
    outputs are then rebuilt from a second copy of the operand, each
    holding its single chunk and an ``nsplits`` with exactly one split
    per dimension covering the full extent.

    :param op: the LU operand being tiled; must expose exactly three outputs
    :return: the result of ``new_tensors`` for the rebuilt P, L and U tensors
    """
    perm, lower, upper = op.outputs
    tensors = (perm, lower, upper)

    chunk_op = op.copy().reset_key()
    # One keyword dict per output chunk; 'side' tells which factor it holds.
    chunk_kws = [
        {'side': side, 'dtype': tensor.dtype,
         'shape': tensor.shape, 'order': tensor.order,
         'index': (0,) * tensor.ndim}
        for side, tensor in zip('plu', tensors)
    ]
    out_chunks = chunk_op.new_chunks(op.input.chunks, kws=chunk_kws)

    tensor_op = op.copy()
    tensor_kws = [tensor.params for tensor in tensors]
    for position, (params, tensor) in enumerate(zip(tensor_kws, tensors)):
        # A single chunk per tensor: every dimension has one split of full size.
        params['nsplits'] = tuple((extent,) for extent in tensor.shape)
        params['chunks'] = [out_chunks[position]]
    return tensor_op.new_tensors(op.inputs, kws=tensor_kws)

@classmethod
def tile(cls, op):
if len(op.input.chunks) == 1:
return cls._tile_one_chunk(op)

from ..arithmetic.subtract import TensorSubtract
from ..arithmetic.add import TensorTreeAdd
from ..base.transpose import TensorTranspose
Expand Down

0 comments on commit c68ed39

Please sign in to comment.