Remove memory copy in matmul #6179
Conversation
@@ -449,7 +449,7 @@ cpdef ndarray tensordot_core(
         out = _ndarray_init(ret_shape, dtype)
     else:
         if out.dtype != dtype:
-            out = _ndarray_init(ret_shape, dtype)
+            raise NotImplementedError("The out array dtype is mismatched")
I agree that the current implementation, which ignores out, seems wrong. I'm not sure if the change is better than commenting # TODO: Fix to write to out.
FWIW NumPy's matmul will just error in this case (though it does allow casting to lower precision, e.g. float64 to float32).
import numpy as np
a = np.random.random((2, 3))
b = np.random.random((3, 2))
c = np.empty((2, 2), dtype=int)
np.matmul(a, b, out=c)
---------------------------------------------------------------------------
UFuncTypeError Traceback (most recent call last)
<ipython-input-7-a4f34170f335> in <module>
5 c = np.empty((2, 2), dtype=int)
6
----> 7 np.matmul(a, b, out=c)
UFuncTypeError: Cannot cast ufunc 'matmul' output from dtype('float64') to dtype('int64') with casting rule 'same_kind'
Note: UFuncTypeError is just a TypeError subclass.
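Because of that subclassing, existing error handling that catches TypeError also catches this failure. A minimal check with NumPy:

```python
import numpy as np

a = np.random.random((2, 3))
b = np.random.random((3, 2))
c = np.empty((2, 2), dtype=int)  # float64 result cannot be cast to int under 'same_kind'

caught = None
try:
    np.matmul(a, b, out=c)
except TypeError as exc:  # UFuncTypeError subclasses TypeError, so this catches it
    caught = exc
```
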
@@ -10,7 +10,8 @@
 from cupy.linalg import _util

 _gu_func_matmul = _GUFunc(
-    _core.matmul, '(n?,k),(k,m?)->(n?,m?)', supports_batched=True)
+    _core.matmul, '(n?,k),(k,m?)->(n?,m?)', supports_batched=True,
+    supports_out=True)
Yes. This is necessary to eliminate copy operations. The general out support in cupy._core._gufuncs._GUFunc cannot know that C-contiguous output is assumed at the cuBLAS call in cupy._core.matmul.
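As an aside, the contiguity issue can be illustrated in NumPy terms (a minimal sketch; the actual cuBLAS constraint lives inside CuPy):

```python
import numpy as np

# A transposed view is F-contiguous, not C-contiguous.
out = np.empty((4, 2)).T
print(out.flags.c_contiguous)  # False

# A kernel that assumes a C-contiguous output buffer cannot write into
# such a view directly; without explicit out handling it must allocate a
# C-contiguous temporary and copy the result back afterwards.
```
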
I think there is no problem in normal usage. Do you think any measures need to be taken?
Thanks to supports_out=False (the default), cupy.matmul did not hit this NotImplementedError. Thus, for correctness, the out support must be complete before declaring supports_out=True.
Please test something like
out = xp.zeros((2, 4), xp.float32)[::-1]
return xp.matmul(xp.ones((2, 3)), xp.ones((3, 4)), out=out)
and
out = xp.zeros((2, 4), bool)
xp.matmul(xp.ones((2, 3)), xp.ones((3, 4)), out=out, casting='unsafe')
BTW, I found a bug: cupy.matmul returns a view of out instead of out itself.
Co-authored-by: Toshiki Kataoka <tos.lunar@gmail.com>
LGTM
/test mini
Remove memory copy in matmul
The current implementation incurs extra memory consumption and a memory copy in the matmul operation. This PR eliminates them.
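The saving is visible at the API level: with out= the product is written in place rather than into a freshly allocated array that is then copied back (NumPy shown here for the reference semantics):

```python
import numpy as np

a = np.random.random((256, 256))
b = np.random.random((256, 256))
out = np.empty((256, 256))

res = np.matmul(a, b, out=out)  # result lands directly in out
assert res is out               # no intermediate array, no copy back
```
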