In [1]:
import numpy as np
from scipy.stats import unitary_group
from functools import reduce

In [2]:
def kron_mv(v, *matrices):
    m = reduce(np.kron, matrices)
    return m.dot(v)

In [3]:
num_sites = 6
local_dim = 4

matrices = [unitary_group.rvs(local_dim) for _ in range(num_sites)]

v = np.random.rand(local_dim ** num_sites) + 1j*np.random.rand(local_dim ** num_sites)
v /= np.linalg.norm(v)

In [4]:
def kron_mv_low_mem(x, *matrices):
    n = [m.shape[0] for m in matrices]
    l = np.prod(n)
    r = 1
    V = x.astype(complex)
    for s in range(len(n))[::-1]:
        l //= n[s]
        m = matrices[s]
        for k in range(l):
            for i in range(r):
                slc = slice(
                    k*n[s]*r + i,
                    (k+1)*n[s]*r + i,
                    r,
                )
                U = V[slc]
                V[slc] = np.dot(m, U)
                
        r *= n[s]
    
    return V

In [5]:
%timeit kron_mv(v, *matrices)

567 ms ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%timeit kron_mv_low_mem(v, *matrices)

10.4 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
np.allclose(
    kron_mv(v, *matrices), 
    kron_mv_low_mem(v, *matrices)
)

True