In [1]:
from e3nn import o3
from rtp import ReducedTensorProducts

from e3nn_jax._symmetric_powers import *

import multiprocessing as mp
import time

In [2]:
# Reduced tensor-product basis for rank-2 tensors over two '1e' indices with the
# index symmetry ij=ji (local `rtp` module; presumably mirrors e3nn's
# ReducedTensorProducts — confirm).
rtp = ReducedTensorProducts('ij=ji', i='1e', j='1e')

In [3]:
# Output irreps of the symmetric (ij=ji) product built above; the saved output
# (1x0e+1x2e) contains no 1e component.
rtp.irreps_out

1x0e+1x2e

In [12]:
# Reference irreps for comparison with rtp.irreps_out.
# NOTE(review): an explicit `assert rtp.irreps_out == o3.Irreps(...)` would make this
# check survive Restart & Run All instead of relying on comparing outputs by eye.
o3.Irreps('1x0e + 1x2e')

1x0e+1x2e

In [2]:
def _symmetric_powers_mp_worker(result_queue, l, n, lmax):
    """Child-process entry point: send ``symmetric_powers_mp(l, n, lmax)`` back via ``result_queue``.

    Lives at module level rather than as a closure inside ``symmetric_powers_mp``:
    with the "spawn"/"forkserver" start methods ``multiprocessing`` must pickle the
    ``Process`` target, and nested functions cannot be pickled.

    Puts ``(True, result)`` on success or ``(False, formatted_traceback)`` on any
    exception, so the parent always receives exactly one message and never blocks
    forever waiting on a child that crashed.
    """
    import traceback

    try:
        result_queue.put((True, symmetric_powers_mp(l, n, lmax)))
    except BaseException:
        result_queue.put((False, traceback.format_exc()))


def symmetric_powers_mp(l, n, lmax):
    r"""Multiprocessing variant of ``symmetric_powers_s``: same recursion, same result.

    For each output order ``l_out <= lmax``, returns an orthonormalized basis of
    rank-``n`` tensors whose indices all have order ``l`` (described by the original
    e3nn_jax docstring as the symmetric powers of the Wigner 3j symbol).

    Rank ``n`` is split into ``n // 2`` and ``n - n // 2``; both halves are built
    recursively and every pair of sub-tensors (orders ``l1``, ``l2``) is coupled with
    ``product_lll`` for each ``l_out`` in ``|l1 - l2| .. min(lmax, l1 + l2)``. For odd
    ``n`` the larger half (``n // 2 + 1``) is computed in one child process while the
    parent computes the smaller half. Even ``n`` reuses a single half and adds no
    process of its own (only deeper odd levels do).

    Note: with the "spawn"/"forkserver" start methods the child must import this
    function and its worker by module name, which fails for code defined interactively
    in a notebook; move both into a ``.py`` module in that case. With "fork" this
    works as-is.

    Args:
        l (int): the order of the indices
        n (int): the rank of the tensor (number of indices), ``n > 0``
        lmax (int): the maximum order in the output, ``lmax >= l``
    Returns:
        dict of the form {l_out: list of sympy arrays}: each array is of shape
        ``(2l_out+1, 2l+1, ..., 2l+1)``; orders whose basis is empty are omitted.
    Raises:
        AssertionError: if ``n <= 0`` or ``l > lmax``.
        RuntimeError: if the child process raised; its traceback is included.
    """
    assert n > 0
    assert l <= lmax

    if n == 1:
        return {l: [sympy.Array(sympy.eye(2 * l + 1))]}

    if n % 2 == 0:
        # Both halves have rank n // 2: compute once and couple it with itself.
        sub1 = sub2 = symmetric_powers_mp(l, n // 2, lmax)
    else:
        result_queue = mp.Queue()
        worker = mp.Process(
            target=_symmetric_powers_mp_worker,
            args=(result_queue, l, n // 2 + 1, lmax),
        )
        worker.start()
        try:
            # Overlap work: the parent computes the smaller half meanwhile.
            sub1 = symmetric_powers_mp(l, n // 2, lmax)
            # get() BEFORE join(): a child that has put data on a queue does not
            # terminate until that data is consumed, so join-then-get can deadlock.
            # Only one message is ever on this queue, so there is no ordering race.
            ok, payload = result_queue.get()
        except BaseException:
            worker.terminate()
            raise
        finally:
            worker.join()
            result_queue.close()
        if not ok:
            raise RuntimeError(
                f"symmetric_powers_mp({l}, {n // 2 + 1}, {lmax}) failed in a child process:\n{payload}"
            )
        sub2 = payload

    res = defaultdict(list)
    for l1, tensors1 in sub1.items():
        for l2, tensors2 in sub2.items():
            for lout in range(abs(l1 - l2), min(lmax, l1 + l2) + 1):
                for a in tensors1:
                    for b in tensors2:
                        res[lout].append(product_lll(lout, a, b))

    # `lo` rather than `l` so the comprehension variable is not mistaken for the argument.
    res = {lo: solve_symmetric(z) for lo, z in res.items()}
    res = {lo: orthonormalize(z)[0] for lo, z in res.items()}
    res = {lo: z for lo, z in res.items() if len(z) > 0}
    res = {lo: [sympy.simplify(x) for x in z] for lo, z in res.items()}
    return res

In [3]:
def symmetric_powers_s(l, n, lmax):
    r"""Serial reference implementation of the recursive symmetric-power construction.

    For each output order ``l_out <= lmax``, returns an orthonormalized basis of
    rank-``n`` tensors whose indices all have order ``l`` (described by the original
    e3nn_jax docstring as the symmetric powers of the Wigner 3j symbol).

    The rank is split into ``n // 2`` and ``n - n // 2``. Each half is built
    recursively (for even ``n`` one half is reused for both sides). Every pair of
    sub-tensors with orders ``l1``, ``l2`` is coupled via ``product_lll`` for each
    ``l_out`` in ``|l1 - l2| .. min(lmax, l1 + l2)``. The collected products for each
    order go through ``solve_symmetric`` and then ``orthonormalize``. Orders with an
    empty basis are dropped, and every remaining entry is ``sympy.simplify``-ed.

    Args:
        l (int): the order of the indices
        n (int): the rank of the tensor (number of indices), ``n > 0``
        lmax (int): the maximum order in the output, ``lmax >= l``
    Returns:
        dict of the form {l_out: list of sympy arrays}: each array is of shape
        ``(2l_out+1, 2l+1, ..., 2l+1)``
    """
    assert n > 0
    assert l <= lmax

    if n == 1:
        # Base case: a single-element list holding the (2l+1)x(2l+1) identity.
        return {l: [sympy.Array(sympy.eye(2 * l + 1))]}

    half = n // 2
    left = symmetric_powers_s(l, half, lmax)
    right = left if n % 2 == 0 else symmetric_powers_s(l, half + 1, lmax)

    products = defaultdict(list)
    for l_left, tensors_left in left.items():
        for l_right, tensors_right in right.items():
            lo_bound = abs(l_left - l_right)
            hi_bound = min(lmax, l_left + l_right)
            for l_out in range(lo_bound, hi_bound + 1):
                for x in tensors_left:
                    for y in tensors_right:
                        products[l_out].append(product_lll(l_out, x, y))

    basis_by_order = {}
    for l_out, terms in products.items():
        basis = orthonormalize(solve_symmetric(terms))[0]
        if len(basis) > 0:
            basis_by_order[l_out] = [sympy.simplify(t) for t in basis]
    return basis_by_order

In [7]:
def btime(N, func, *args, **kwargs):
    """Return the mean wall-clock time, in seconds, of ``N`` calls to ``func(*args, **kwargs)``.

    Uses ``time.perf_counter``, a monotonic clock with the highest available
    resolution. This is the clock the ``time`` docs recommend for measuring short
    durations; ``time.time`` can jump if the system clock is adjusted. The return
    values of ``func`` are discarded.

    Args:
        N (int): number of timed calls; must be >= 1.
        func (callable): the function to benchmark.
        *args, **kwargs: forwarded unchanged to ``func`` on every call.
    Returns:
        float: mean seconds per call.
    Raises:
        ValueError: if ``N < 1`` (an average over zero runs is undefined).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    total = 0.0
    for _ in range(N):
        start = time.perf_counter()
        func(*args, **kwargs)
        total += time.perf_counter() - start
    return total / N

In [8]:
# Mean seconds per call over 10 runs of the serial version with l=1, n=4, lmax=3.
btime(10, symmetric_powers_s, 1, 4, 3)

2.9206363201141357

In [9]:
# Mean seconds per call over 10 runs of the multiprocessing version, same arguments.
# NOTE(review): n=4 recurses 4 -> 2 -> 1 through the even (n % 2 == 0) branch only,
# so the odd-n multiprocessing branch never executes here. The gap between the two
# saved timings is therefore run-to-run noise, not parallel speedup. Use an odd n
# (e.g. 3, 5 or 7) to actually exercise the child processes.
btime(10, symmetric_powers_mp, 1, 4, 3)

2.6179674863815308