In [1]:
import timeit
from numba import jit, prange
import numpy as np
from functools import lru_cache, wraps

In [2]:
def npCacheMap(*args, **kwargs):
    """LRU cache implementation for functions whose FIRST parameter is a numpy array
    forked from: https://gist.github.com/Susensio/61f4fee01150caaac1e10fc5f005eb75"""

    def decorator(function):
        @wraps(function)
        def wrapper(np_array, *args, **kwargs):
            hashable_array = array_to_tuple(np_array)#tuple(map(tuple, np_array))
            return cached_wrapper(hashable_array, *args, **kwargs)

        @lru_cache(*args, **kwargs)
        def cached_wrapper(hashable_array, *args, **kwargs):
            array = np.array(hashable_array)
            return function(array, *args, **kwargs)
        
        def array_to_tuple(np_array):
            """Iterates recursivelly."""
            try:
                return tuple(array_to_tuple(_) for _ in np_array)
            except TypeError:
                return np_array

        # copy lru_cache attributes over too
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear
        return wrapper

    return decorator

In [3]:
from joblib import Memory
location = "./cachedir"
memory = Memory(location, verbose=0)

In [4]:
memory.clear(warn=False)

In [5]:
@jit(forceobj=True)
def matmult_l(us):
    result = us[0]
    for u in us[1:]:
        result = result @ u

    return result


@jit(parallel=True)
def matmult_even(us):

    l, dims, _ = us.shape

    double_us = np.zeros((l // 2, dims, dims))

    for i in prange(0, l // 2):
        double_us[i, :, :] = us[i * 2 + 1] @ us[i * 2]

    return double_us

# @npCacheMap()
@jit(forceobj=True)
def matmult_t(us):
    """Recursively multiply the neighbouring gates.
    When the block size gets below the turnover point the linear
    kron_gates_l is used as it is more efficient in this usecase."""
    TURNOVER = 12

    l = len(us)

    if l > TURNOVER:
        if l % 2 == 0:
            return matmult_t(matmult_even(us))
        return us[-1] @ matmult_t( matmult_even(us[:-1, :, :]))

    return matmult_l(us)

In [6]:
# first define some testing data


test_f = lambda tp: np.random.rand(tp, 32, 32)

In [7]:
list(map(lambda a: (matmult_l(a) == matmult_t(a)).all(), [test_f(i) for i in range(1, 10)]))

  result = result @ u


[True, True, True, True, True, True, True, True, True]

In [10]:
verbose = True

for i in range(2,10):
    test = test_f(i)

    tl = timeit.timeit(lambda: matmult_l(test.copy()), number=1000)
    tt = timeit.timeit(lambda: matmult_t(test.copy()), number=1000)

    if verbose:
        print(
            f"""Testing natmult_l for multiplying {i} gates \r
        {tl} """
        )

        print(
            f"""Testing matmult_t for multiplying {i} gates \r
        {tt} """
        )

    if tl < tt:
        print(f"PREFER LINEAR (i={i})")

    if tt < tl:
        print(f"PREFER RECURSIVE (i={i})")

Testing natmult_l for multiplying 2 gates 
        0.0066094999929191545 
Testing matmult_t for multiplying 2 gates 
        0.006867100004456006 
PREFER LINEAR (i=2)
Testing natmult_l for multiplying 3 gates 
        0.011272400006419048 
Testing matmult_t for multiplying 3 gates 
        0.011949399995501153 
PREFER LINEAR (i=3)
Testing natmult_l for multiplying 4 gates 
        0.01409589999821037 
Testing matmult_t for multiplying 4 gates 
        0.014186999993398786 
PREFER LINEAR (i=4)
Testing natmult_l for multiplying 5 gates 
        0.018018900009337813 
Testing matmult_t for multiplying 5 gates 
        0.018075300002237782 
PREFER LINEAR (i=5)
Testing natmult_l for multiplying 6 gates 
        0.02061429999594111 
Testing matmult_t for multiplying 6 gates 
        0.021542699993005954 
PREFER LINEAR (i=6)
Testing natmult_l for multiplying 7 gates 
        0.02367769999545999 
Testing matmult_t for multiplying 7 gates 
        0.023971899994648993 
PREFER LINEAR (i=7)
Testin