Description
I think that the following new feature would make `numpy.einsum` even more powerful/useful/awesome than it already is. Moreover, the change would not interfere with existing code, would preserve the "minimalistic" spirit of `numpy.einsum`, and would integrate seamlessly and intuitively for users.
In short, the new feature would allow repeated subscripts to appear in the "output" part of the `subscripts` parameter (i.e., on the right-hand side of `->`). The corresponding dimensions of the resulting `ndarray` would be filled only along their diagonal, leaving the off-diagonal entries at the default value for its `dtype` (typically zero). Note that the current behavior is to raise an exception (a `ValueError`) when an output subscript is repeated.
This is simplest to describe with an example involving the dual behavior of `numpy.diag`.
```python
from numpy import arange, diag, einsum

# Extracting the diagonal of a 2-D array.
A = arange(16).reshape(4, 4)
print(diag(A))             # Output: [ 0  5 10 15]
print(einsum('ii->i', A))  # Same as previous line (current behavior).

# Constructing a diagonal 2-D array.
v = arange(4)
print(diag(v))             # Output: [[0 0 0 0] [0 1 0 0] [0 0 2 0] [0 0 0 3]]
print(einsum('i->ii', v))  # Proposed behavior: same as the previous line.
# Currently, the previous line raises an exception instead.
```
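Until NumPy supports this, the proposed semantics can be emulated with the existing `einsum`. The sketch below is a hypothetical helper, `einsum_repeated_out` (not a NumPy function), and assumes explicit `->` notation with no ellipsis. It first contracts onto the de-duplicated output subscripts, then writes that result into a zero array through the writeable diagonal view that a single-operand `einsum` call with repeated input subscripts returns (NumPy 1.10 and later).

```python
import numpy as np


def einsum_repeated_out(subscripts, *operands):
    """Hypothetical helper (not part of NumPy) emulating the proposed
    repeated-output-subscript semantics of numpy.einsum.

    Assumes explicit '->' notation and no ellipsis ('...').
    """
    inputs, output = subscripts.split('->')
    # Output labels with duplicates removed, in order of first appearance.
    unique = ''.join(dict.fromkeys(output))
    # Ordinary einsum contraction onto the de-duplicated labels.
    compact = np.einsum(inputs + '->' + unique, *operands)
    # Size of each output label, read off the compact result.
    sizes = dict(zip(unique, compact.shape))
    # Off-diagonal entries keep the dtype's default value (zero).
    result = np.zeros([sizes[label] for label in output], dtype=compact.dtype)
    # A single-operand einsum with repeated input labels returns a
    # writeable view of the diagonal (NumPy >= 1.10); assign through it.
    np.einsum(output + '->' + unique, result)[...] = compact
    return result


v = np.arange(4)
print(einsum_repeated_out('i->ii', v))  # same as np.diag(v)
```

The same helper accepts mixed patterns such as `'wab,ywaab->ayyab'`, since it only relies on `einsum` for the contraction and for the diagonal view.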
In contrast to `numpy.diag`, the approach generalizes to higher dimensions: when `A` is a 3-D array, `einsum('iii->i', A)` extracts its diagonal, and `einsum('i->iii', v)` would build a diagonal 3-D array.
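For the pure-diagonal case in any dimension, current NumPy does offer a workaround: `np.fill_diagonal` writes to the locations `a[i, ..., i]` of an array whose dimensions are all equal. The sketch below uses it to build the 3-D array that `einsum('i->iii', v)` would return under the proposal, then recovers `v` with the already supported `einsum('iii->i', ...)`.

```python
import numpy as np

v = np.arange(4)

# A 4x4x4 array that is zero except on its main diagonal T[i, i, i].
T = np.zeros((v.size,) * 3, dtype=v.dtype)
np.fill_diagonal(T, v)  # writes v[i] to T[i, i, i]

# Extracting the diagonal of a 3-D array already works with einsum.
print(np.einsum('iii->i', T))  # [0 1 2 3]
```

Unlike the proposed notation, `fill_diagonal` only covers the case where every axis shares a single index; it cannot express mixed patterns such as `'ayyab'`.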
The proposed behavior really starts to shine in more intricate cases.
```python
from numpy import arange, einsum, eye

# Dummy values; these should be probabilities for the computation below to make sense.
P_w_ab = arange(24).reshape(3, 2, 4)
P_y_wxab = arange(144).reshape(3, 3, 2, 2, 4)

# With the proposed behavior, the following two lines should be equivalent.
P_xyz_ab = einsum('wab,xa,ywxab,zy->xyzab', P_w_ab, eye(2), P_y_wxab, eye(3))
also_P_xyz_ab = einsum('wab,ywaab->ayyab', P_w_ab, P_y_wxab)  # currently raises
```
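The claimed equivalence can already be checked against the eye-based line, which current NumPy accepts. Under the proposal, `'ayyab'` means: contract as `'wab,ywaab->ayb'`, then place that result on the entries where the first index equals the fourth and the second equals the third, with zeros elsewhere. The sketch below verifies that the eye-based result has exactly this structure; `compact` and `mask` are illustrative names, not part of the proposal.

```python
from numpy import arange, array_equal, einsum, eye

P_w_ab = arange(24).reshape(3, 2, 4)
P_y_wxab = arange(144).reshape(3, 3, 2, 2, 4)

# Supported today: the diagonal structure is imposed by identity matrices.
P_xyz_ab = einsum('wab,xa,ywxab,zy->xyzab', P_w_ab, eye(2), P_y_wxab, eye(3))

# The same contraction onto the de-duplicated output labels a, y, b.
compact = einsum('wab,ywaab->ayb', P_w_ab, P_y_wxab)

# On the x == a, y == z diagonal, the eye-based result equals `compact`.
print(array_equal(einsum('ayyab->ayb', P_xyz_ab), compact))  # True

# Everywhere else it is zero: mask[x, y, z, a] is True iff x == a and y == z.
mask = einsum('xa,yz->xyza', eye(2), eye(3)).astype(bool)
print(P_xyz_ab[~mask].any())  # False
```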
If this is not convincing enough, replace `eye(2)` by `eye(P_w_ab.shape[1])` and `eye(3)` by `eye(P_y_wxab.shape[0])`, then imagine more dimensions and repeated indices... The new notation would allow for crisper code and reduce the opportunities for dumb mistakes.
For those who wonder, the above computation amounts to