In [1]:
import numpy as np
import functools,operator

In [2]:
def contraction(A,B,indA,indB):
    indA=np.array(indA)
    indB=np.array(indB)
    sizeA=len(A.shape)
    sizeB=len(B.shape)
    indA2=np.setdiff1d(range(sizeA),indA)
    indB2=np.setdiff1d(range(sizeB),indB)
    Ashape=A.shape
    Bshape=B.shape
    A=np.moveaxis(A,indA2,range(len(indA2)))
    tempB=list(range(-len(indB2),0,1))
    B=np.moveaxis(B,indB2,tempB)
    #B=np.moveaxis(B,[0],[1]) #problem is here in move axis, in general I don't give the same ordering on thick leg
    Aorig = A.shape[:-len(indA)]
    Acontr= A.shape[-len(indA):]
    Borig = B.shape[len(indB):]
    Bcontr= B.shape[:len(indB)]
    try:
        dimsA1=functools.reduce(operator.mul,Aorig)
        dimsA2 = functools.reduce(operator.mul, Acontr)
        A=A.reshape((dimsA1,dimsA2))
    except:
        A=A.flatten()
    try:
        tempB=list(range(0,len(indB)))
        reverseds=tempB[::-1]
        B=np.moveaxis(B,tempB,reverseds)
        dimsB1 = functools.reduce(operator.mul, Borig)
        dimsB2 = functools.reduce(operator.mul, Bcontr)
        B=B.reshape((dimsB2,dimsB1))
    except:
        B=B.flatten()
    C=A@B
    try:
        newdims = np.concatenate((Aorig, Borig))
        return C.reshape(newdims)
    except:
        return C.reshape(Aorig)

In [3]:
Da = 10; # alpha()
Db = 12; # beta()
Dc = 14; # gamma()
Dd = 17; # delta
Dm = 20; # mu

In [4]:
A = np.random.rand(Dc,Dd);   # tensor A[gamma,delta]
A[A<0.8]=0 #to make it sparse
B = np.random.rand(Da,Dm,Dc); # tensor B[alpha,mu,gamma]
B[B<0.8]=0
C = np.random.rand(Db,Dm,Dd); # tensor C[beta,mu,delta]
C[C<0.8]=0

In [5]:
A.shape,B.shape

((14, 17), (10, 20, 14))

In [6]:
AB = contraction(A,B,[0],[2]);
ABC=contraction(AB,C,[0,2],[2,1]);

In [7]:
%%timeit
AB = contraction(A,B,[0],[2]);
ABC=contraction(AB,C,[0,2],[2,1]);


349 μs ± 33.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
AB_True=np.einsum('ji,klj->ikl',A,B)
ABC_True=np.einsum('lmn,onl->mo',AB_True,C)

In [9]:
%%timeit
AB_True=np.einsum('ji,klj->ikl',A,B)
ABC_True=np.einsum('lmn,onl->mo',AB_True,C)

50.2 μs ± 2.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
np.isclose(ABC,ABC_True).all()

True

In [11]:
import jax.numpy as jnp


In [12]:
def contraction_jax(A,B,indA,indB):
    indA=jnp.array(indA)
    indB=jnp.array(indB)
    sizeA=len(A.shape)
    sizeB=len(B.shape)
    indA2=jnp.setdiff1d(jnp.array(range(sizeA)),indA)
    indB2=jnp.setdiff1d(jnp.array(range(sizeB)),indB)
    Ashape=A.shape
    Bshape=B.shape
    A=jnp.moveaxis(A,indA2,range(len(indA2)))
    tempB=list(range(-len(indB2),0,1))
    B=jnp.moveaxis(B,indB2,tempB)
    #B=np.moveaxis(B,[0],[1]) #problem is here in move axis, in general I don't give the same ordering on thick leg
    Aorig = A.shape[:-len(indA)]
    Acontr= A.shape[-len(indA):]
    Borig = B.shape[len(indB):]
    Bcontr= B.shape[:len(indB)]
    try:
        dimsA1=functools.reduce(operator.mul,Aorig)
        dimsA2 = functools.reduce(operator.mul, Acontr)
        A=A.reshape((dimsA1,dimsA2))
    except:
        A=A.flatten()
    try:
        tempB=list(range(0,len(indB)))
        reverseds=tempB[::-1]
        B=jnp.moveaxis(B,tempB,reverseds)
        dimsB1 = functools.reduce(operator.mul, Borig)
        dimsB2 = functools.reduce(operator.mul, Bcontr)
        B=B.reshape((dimsB2,dimsB1))
    except:
        B=B.flatten()
    C=A@B
    try:
        newdims = jnp.concatenate((jnp.array(Aorig), jnp.array(Borig)))
        return C.reshape(newdims)
    except:
        return C.reshape(Aorig)

In [13]:
import jax
jax.config.update("jax_enable_x64", True)

In [14]:
Ajax= jnp.array(A) 
Bjax= jnp.array(B) 
Cjax= jnp.array(C) 

In [15]:
%%timeit
ABjax = contraction_jax(Ajax,Bjax,[0],[2]);
ABCjax=contraction_jax(ABjax,Cjax,[0,2],[2,1]);

14.7 ms ± 998 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
ABjax = contraction(Ajax,Bjax,[0],[2]);
ABCjax=contraction(ABjax,Cjax,[0,2],[2,1]);

In [17]:
np.isclose(np.array(ABCjax),ABC_True).all()

True

# Scipy sparse matrices.


In [20]:
from scipy.sparse import coo_array

In [23]:
#Asparse=coo_array(A)
Bsparse=coo_array(B)
#Csparse=coo_array(C)

TypeError: expected dimension <= 2 array or matrix