numpy.einsum new feature request: repeated output subscripts as diagonal #4965

@PierreAndreNoel

I think that the following new feature would make numpy.einsum even more powerful/useful/awesome than it already is. Moreover, the change should not interfere with existing code, it would preserve the "minimalistic" spirit of numpy.einsum, and the new functionality would integrate in a seamless/intuitive manner for the users.

In short, the new feature would allow repeated subscripts to appear in the "output" part of the subscripts parameter (i.e., on the right-hand side of ->). The corresponding dimensions of the resulting ndarray would only be filled along their diagonal, leaving the off-diagonal entries at the default value for the dtype (typically zero). Note that the current behavior is to raise an exception when repeated output subscripts are used.

This is simplest to describe with an example involving the dual behavior of numpy.diag.

from numpy import arange, diag, einsum, eye

# Extracting the diagonal of a 2-D array.
A = arange(16).reshape(4,4)
print(diag(A)) # Output: [ 0 5 10 15 ]
print(einsum('ii->i', A)) # Same as previous line (current behavior).

# Constructing a diagonal 2-D array.
v = arange(4)
print(diag(v)) # Output: [[0 0 0 0] [0 1 0 0] [0 0 2 0] [0 0 0 3]]
print(einsum('i->ii', v)) # New behavior would be same as previous line.
# The current behavior of the previous line is to raise an exception.

Unlike numpy.diag, the approach generalizes to higher dimensions: for a 3-D array T, einsum('iii->i', T) extracts its diagonal (current behavior), and einsum('i->iii', v) would build a diagonal 3-D array (new behavior).
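For comparison, the diagonal-building behavior can be emulated with current NumPy by writing through the view that einsum already returns for diagonal extraction. This is only a sketch of a workaround, not the proposed implementation; it assumes NumPy 1.10 or later, where such views are writeable, and the names D2 and D3 are made up for illustration.

```python
import numpy as np

v = np.arange(4)

# 2-D case: what the proposed einsum('i->ii', v) would return.
D2 = np.zeros((v.size, v.size), dtype=v.dtype)
np.einsum('ii->i', D2)[...] = v   # 'ii->i' is a writeable view of the diagonal

# 3-D case: what the proposed einsum('i->iii', v) would return.
D3 = np.zeros((v.size,) * 3, dtype=v.dtype)
np.einsum('iii->i', D3)[...] = v  # fills D3[i, i, i] = v[i]
```

The off-diagonal entries keep the zeros they were allocated with, which matches the "default value" behavior described above.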

The proposed behavior really starts to shine in more intricate cases.

# Dummy values, these should be probabilities to make sense below.
P_w_ab = arange(24).reshape(3,2,4)
P_y_wxab = arange(144).reshape(3,3,2,2,4)

# With the proposed behavior, the following two lines should be equivalent.
P_xyz_ab = einsum('wab,xa,ywxab,zy->xyzab', P_w_ab, eye(2), P_y_wxab, eye(3))
also_P_xyz_ab = einsum('wab,ywaab->ayyab', P_w_ab, P_y_wxab)

If this is not convincing enough, replace eye(2) by eye(P_w_ab.shape[1]) and replace eye(3) by eye(P_y_wxab.shape[0]), then imagine more dimensions and repeated indices... The new notation would allow for crisper code and reduce the opportunities for dumb mistakes.
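To make the claimed equivalence checkable today, here is a rough sketch of a helper that emulates repeated output subscripts on top of the existing einsum. The name einsum_repeated_out is hypothetical and not part of NumPy. It assumes an explicit '->' with at least one output label and no ellipsis, and it relies on single-operand diagonal views being writeable (NumPy 1.10 or later).

```python
import numpy as np

def einsum_repeated_out(subscripts, *operands):
    """Hypothetical helper: allow repeated labels on the output side."""
    inputs, output = subscripts.replace(" ", "").split("->")
    # Output labels with duplicates removed, in order of first appearance.
    unique = "".join(dict.fromkeys(output))
    # Entries that will sit on the generalized diagonal; this is an
    # expression that current einsum already accepts.
    core = np.einsum(f"{inputs}->{unique}", *operands)
    sizes = dict(zip(unique, core.shape))
    # Allocate the full result, zero everywhere off the diagonal.
    out = np.zeros([sizes[c] for c in output], dtype=core.dtype)
    # Diagonal extraction on a single operand returns a writeable view.
    np.einsum(f"{output}->{unique}", out)[...] = core
    return out

P_w_ab = np.arange(24).reshape(3, 2, 4)
P_y_wxab = np.arange(144).reshape(3, 3, 2, 2, 4)

expected = np.einsum('wab,xa,ywxab,zy->xyzab',
                     P_w_ab, np.eye(2), P_y_wxab, np.eye(3))
emulated = einsum_repeated_out('wab,ywaab->ayyab', P_w_ab, P_y_wxab)
print(np.array_equal(expected, emulated))
```

A native implementation would presumably avoid the intermediate array; the point here is only that the two subscript strings describe the same result.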

For those who wonder, the above computation amounts to $P(X=x,Y=y,Z=z|A=a,B=b) = \sum_w P(W=w|A=a,B=b) P(X=x|A=a) P(Y=y|W=w,X=x,A=a,B=b) P(Z=z|Y=y)$ with $P(X=x|A=a)=\delta_{xa}$ and $P(Z=z|Y=y)=\delta_{zy}$ (in LaTeX notation, where $\delta_{ij}$ is the Kronecker delta).
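As a worked step, substituting the two deltas and pulling them out of the sum (neither depends on $w$) gives the form that the repeated-subscript expression 'wab,ywaab->ayyab' encodes directly:

$$P(X=x,Y=y,Z=z|A=a,B=b) = \delta_{xa}\,\delta_{zy} \sum_w P(W=w|A=a,B=b)\, P(Y=y|W=w,X=a,A=a,B=b)$$

The result is nonzero only where $x=a$ and $z=y$, which is exactly the "diagonal" that the repeated output subscripts a and y would select.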
