In [None]:
from sympy import S, integrate, trigsimp, expand_trig, I, symbols, sqrt, Symbol
import numpy as np
from IPython.display import display
from functools import lru_cache

In [None]:
# -(q1 . q2): the scalar product, kept under its own name instead of being
# overwritten (dead store) by the next assignment.
expr_dot = S("-q11*q21 - q12*q22 - q13*q23")
# The expression carried through the spherical / angular substitutions below.
expr = S("-q11*q23 + I*q12*q23 + q13*q21 - I*q13*q22")

In [None]:
def get_spherical_substitutions(pp, nn):
    """Map the Cartesian components of vector ``{pp}{nn}`` to spherical form.

    The returned strings are meant for ``Expr.subs`` (which sympifies them).
    The vector ``{pp}{nn}`` is expressed through its magnitude ``{pp}{nn}``,
    ``x{nn} = cos(theta)`` and azimuth ``phi{nn}``::

        {pp}{nn}1 -> {pp}{nn} * sqrt(1 - x**2) * cos(phi)
        {pp}{nn}2 -> {pp}{nn} * sqrt(1 - x**2) * sin(phi)
        {pp}{nn}3 -> {pp}{nn} * x

    Putting ``x`` (not ``sqrt(1 - x**2)``) on the z component lets x range
    over [-1, 1]. With phi in [0, 2*pi) this covers the whole sphere and the
    solid-angle measure is simply dx dphi. A ``sqrt(1 - x**2)`` z component
    would be non-negative and could never describe the lower hemisphere.

    Args:
        pp: symbol prefix, e.g. ``"q"``.
        nn: vector index appended to the prefix, e.g. ``1`` -> ``"q1"``.

    Returns:
        dict[str, str]: component name -> spherical expression.
    """
    return {
        f"{pp}{nn}1": f"{pp}{nn} * sqrt(1 - x{nn}**2) * cos(phi{nn})",
        f"{pp}{nn}2": f"{pp}{nn} * sqrt(1 - x{nn}**2) * sin(phi{nn})",
        f"{pp}{nn}3": f"{pp}{nn} * x{nn}",
    }


# Preview the substitution map generated for prefix "q", index 1.
get_spherical_substitutions("q", 1)

In [None]:
# Combined component -> spherical map for both vectors q1 and q2.
substitutions = {}
for vector_index in (1, 2):
    substitutions.update(get_spherical_substitutions("q", vector_index))
substitutions

In [None]:
# Rewrite the Cartesian expression in spherical variables and tidy it up.
expr_spherical = (
    expr
    .subs(substitutions)
    .expand()
    .simplify()
)
expr_spherical

In [None]:
# Total / relative azimuths: Phi = (phi1 + phi2) / 2 and phi = phi1 - phi2,
# hence phi1 = Phi + phi/2 and phi2 = Phi - phi/2. The two entries must differ
# by phi; identical entries would force phi1 == phi2 and make the change of
# variables degenerate.
angular_substitutions = {
    "phi1": "Phi + phi/2",
    "phi2": "Phi - phi/2",
}

In [None]:
# Swap phi1/phi2 for the (Phi, phi) angles, then let trigsimp collapse the
# resulting products of sines and cosines.
expr_spherical_cms = trigsimp(
    expr_spherical
    .subs(angular_substitutions)
    .expand()
    .simplify()
)
expr_spherical_cms

In [None]:
# Integrate the (Phi, phi) form from the previous cell over Phi in [0, 2*pi].
# No single-letter alias: a global `e` would be clobbered (and deleted) by any
# later `except ... as e`.
integrate(expr_spherical_cms, ("Phi", 0, "2*pi"))

In [None]:
# Cartesian components q{i}{j} (vector i, component j), created without
# assumptions so they compare equal to the names parsed from the strings in
# `substitutions`.
q11, q12, q13, q21, q22, q23 = symbols("q11 q12 q13 q21 q22 q23")
# Hand-derived expressions bilinear in q1 and q2, keyed by 4-tuples.
# NOTE(review): keys look like (l1, m1, l2, m2) quantum numbers -- confirm the
# convention against the original derivation before relying on it.
# Sanity anchor: (0, 0, 0, 0) reduces to -(q1 . q2).
exprs = {
    (0, 0, 0, 0): -q13 * q23
    - (q11 - I * q12) * (q21 + I * q22) / 2
    - (q11 + I * q12) * (q21 - I * q22) / 2,
    (0, 0, 1, -1): -sqrt(2) * q13 * (-q21 + I * q22) / 2
    + sqrt(2) * q23 * (-q11 + I * q12) / 2,
    (0, 0, 1, 0): (q11 - I * q12) * (q21 + I * q22) / 2
    - (q11 + I * q12) * (q21 - I * q22) / 2,
    (0, 0, 1, 1): sqrt(2) * q13 * (q21 + I * q22) / 2
    - sqrt(2) * q23 * (q11 + I * q12) / 2,
    (1, -1, 0, 0): sqrt(2) * q13 * (q21 + I * q22) / 2
    - sqrt(2) * q23 * (q11 + I * q12) / 2,
    (1, 0, 0, 0): -(q11 - I * q12) * (q21 + I * q22) / 2
    + (q11 + I * q12) * (q21 - I * q22) / 2,
    (1, 1, 0, 0): sqrt(2) * q13 * (q21 - I * q22) / 2
    - sqrt(2) * q23 * (q11 - I * q12) / 2,
    (1, -1, 1, -1): q13 * q23,
    (1, -1, 1, 0): -sqrt(2) * q13 * (q21 + I * q22) / 2
    - sqrt(2) * q23 * (q11 + I * q12) / 2,
    (1, -1, 1, 1): (q11 + I * q12) * (q21 + I * q22),
    (1, 0, 1, -1): sqrt(2) * q13 * (-q21 + I * q22) / 2
    + sqrt(2) * q23 * (-q11 + I * q12) / 2,
    (1, 0, 1, 0): -q13 * q23
    + (q11 - I * q12) * (q21 + I * q22) / 2
    + (q11 + I * q12) * (q21 - I * q22) / 2,
    (1, 0, 1, 1): sqrt(2) * q13 * (q21 + I * q22) / 2
    + sqrt(2) * q23 * (q11 + I * q12) / 2,
    (1, 1, 1, -1): (q11 - I * q12) * (q21 - I * q22),
    (1, 1, 1, 0): sqrt(2) * q13 * (q21 - I * q22) / 2
    + sqrt(2) * q23 * (q11 - I * q12) / 2,
    (1, 1, 1, 1): q13 * q23,
}

In [None]:
# Push every expression in `exprs` through the spherical and angular
# substitutions. A comprehension keeps the iteration variable local, so the
# notebook-level `expr` is not silently replaced by the last entry of `exprs`.
subs_exprs = {
    key: tensor_expr.subs(substitutions).subs(angular_substitutions).expand()
    for key, tensor_expr in exprs.items()
}

In [None]:
# Pick one substituted expression as the running example for the integration
# helpers and timings below.
expr = subs_exprs[(1, -1, 1, 1)]
expr

In [None]:
# Peek at sympy's split of the example into a numeric coefficient and the
# remaining product.
expr.as_coeff_Mul()

In [None]:
@lru_cache(128)
def cached_integrate(*integrand_and_limits, **options):
    """Memoised pass-through to sympy's ``integrate``.

    All arguments are forwarded unchanged, so every one of them must be
    hashable (sympy expressions, strings and tuples thereof are). At most 128
    distinct calls are remembered; ``cached_integrate.cache_clear()`` resets it.
    """
    result = integrate(*integrand_and_limits, **options)
    return result


def integrate_unique_terms(expr, boundaries):
    """Definite integral of ``expr`` over one variable, computed term by term.

    ``expr`` is expanded into a sum of products. For each product, the factors
    that contain the integration variable form a "kernel". Each distinct
    kernel is integrated only once (and also goes through ``cached_integrate``
    across calls); the remaining factors are treated as constants.

    Args:
        expr: sympy expression to integrate.
        boundaries: ``(var, start, end)`` as accepted by ``integrate``; ``var``
            may be a ``Symbol`` or its name as a string.

    Returns:
        The sum over terms of ``constant_part * integral(kernel)``.
    """
    var, start, end = boundaries

    if not isinstance(var, Symbol):
        var = Symbol(var)

    summands, basis = expr.expand().as_terms()

    # kernel -> its definite integral. An explicit membership test is used
    # because dict.setdefault would evaluate the integral eagerly every time,
    # which made the previous local cache useless.
    integrated_kernels = {}

    out = 0
    for term, (_, powers, _) in summands:
        kernel = 1
        for base, power in zip(basis, powers):
            if var in base.free_symbols:
                kernel *= base ** power

        if kernel not in integrated_kernels:
            integrated_kernels[kernel] = cached_integrate(kernel, (var, start, end))

        # term / kernel leaves only the factors independent of `var`.
        out += term / kernel * integrated_kernels[kernel]

    return out


# Sanity check: the term-wise integral must agree with sympy's direct result.
# Simplify before comparing; the raw difference is rarely syntactically zero.
difference = (
    integrate_unique_terms(expr, ("Phi", 0, "2*pi"))
    - integrate(expr, ("Phi", 0, "2*pi"))
).simplify()
assert difference == 0, difference
difference

In [None]:
# Time the term-wise helper against plain sympy integrate on the example.
# After the first loop, integrate_unique_terms hits cached_integrate's
# lru_cache, so its timing reflects the warm-cache path.
%timeit integrate_unique_terms(expr, ("Phi", 0, "2*pi"))
%timeit integrate(expr, ("Phi", 0, "2*pi"))

In [None]:
%%timeit
# Cold-cache timing: clear cached_integrate on every run, then integrate each
# substituted expression, weighted by exp(I*Phi), over Phi in [0, 2*pi].
cached_integrate.cache_clear()
for key, expr in subs_exprs.items():
    integrate_unique_terms(expr * S("exp(I*Phi)"), ("Phi", 0, "2*pi"))

In [None]:
%%timeit
# Baseline: plain sympy integrate on the same exp(I*Phi)-weighted expressions.
for key, expr in subs_exprs.items():
    integrate(expr * S("exp(I*Phi)"), ("Phi", 0, "2*pi"))

In [None]:
# Integrate each exp(I*Phi)-weighted expression both term-wise (s1) and
# directly (s2) so the next cell can check that the two methods agree.
# A dedicated loop variable avoids clobbering the notebook-level `expr` used by
# the example check and timing cells.
phase_factor = S("exp(I*Phi)")
phi_range = ("Phi", 0, "2*pi")
s1 = {}
s2 = {}
for key, tensor_expr in subs_exprs.items():
    weighted = tensor_expr * phase_factor
    s1[key] = integrate_unique_terms(weighted, phi_range)
    s2[key] = integrate(weighted, phi_range)

In [None]:
# Verify that term-wise and direct integration agree for every key. Results are
# paired by key (not by zip order), and the failing key is reported.
for key, unique_terms_result in s1.items():
    direct_result = s2[key]
    try:
        assert (unique_terms_result - direct_result).simplify() == 0, key
    except Exception:
        # Show the offending pair, then re-raise with the original traceback.
        # No `as e`: that would delete any notebook-level `e` on exit.
        display(key, unique_terms_result, direct_result)
        raise