In [None]:
from itertools import product

from pandas import DataFrame, Series, set_option, read_csv
from pandas.testing import assert_frame_equal

from sympy import S, Symbol, expand_trig, Function
from sympy import exquo, ExactQuotientFailed, ComputationFailed

from numpwd.integrate.analytic import SPHERICAL_BASE_SUBS, ANGLE_BASE_SUBS, integrate
from numpwd.qchannels.cg import get_cg
from numpwd.qchannels.spin import (
    expression_to_matrix,
    pauli_contract_subsystem,
    dict_to_data,
)

# Show long sympy expressions in DataFrame cells without truncation. Use the fully
# qualified key: bare patterns are regex-matched and become an error if they ever
# match more than one option.
set_option("display.max_colwidth", None)

In [None]:
# Two-particle spin operator written with indexed Pauli symbols; presumably
# sigma<particle><component> with component 0 the identity — confirm against
# `expression_to_matrix`.
expr = S("sigma10 * (sigma21 * k1 + sigma22 * k2 + sigma23 * k3)")
expr

In [None]:
# Matrix representation of `expr`; "sigma" is the symbol prefix used in `expr` above.
mat = expression_to_matrix(expr, pauli_symbol="sigma")
mat

In [None]:
up = S("1/2")
dn = -up
# Matrix elements of sigma . k between the DM spin states, keyed by
# (outgoing projection, incoming projection).
dm_entries = [
    (up, up, S("k3")),
    (dn, dn, -S("k3")),
    (up, dn, S("k1 - I*k2")),
    (dn, up, S("k1 + I*k2")),
]
dm_mat = [{"ms_o": ms_out, "ms_i": ms_in, "val": value} for ms_out, ms_in, value in dm_entries]
dm_df = DataFrame(dm_mat).set_index(["ms_o", "ms_i"]).sort_index()
dm_df

In [None]:
# Reduce the two-particle spin matrix with `pauli_contract_subsystem`; the result is
# turned into (s_o, ms_o, s_i, ms_i) rows by `dict_to_data` below.
mat12 = pauli_contract_subsystem(mat)
mat12

In [None]:
cols = ["s_o", "ms_o", "s_i", "ms_i"]
# Nucleon spin matrix elements indexed by outgoing/incoming total spin and projection.
nuc_df = (
    DataFrame(dict_to_data(mat12, columns=cols))
    .set_index(cols)
    .sort_index()
)
nuc_df

In [None]:
def df_outer_product(df1, df2, suffixes=None, reset_index=False):
    """Return the row-wise Cartesian product of two frames.

    Every row of ``df1`` is paired with every row of ``df2``; each output row holds
    the columns of both rows. The input index is discarded unless ``reset_index``
    is set, in which case it is first moved into regular columns.

    Parameters
    ----------
    df1, df2 : DataFrame
        Frames to combine; rows of ``df1`` vary slowest in the output.
    suffixes : sequence of two str, optional
        Appended to every column name of ``df1`` and ``df2`` respectively. Without
        suffixes, a column present in both frames takes its value from ``df2``.
    reset_index : bool
        Move the index of both inputs into columns before combining.

    Returns
    -------
    DataFrame
        ``len(df1) * len(df2)`` rows with a fresh RangeIndex. The columns are kept
        even when one of the inputs has no rows.
    """
    left = df1.reset_index() if reset_index else df1
    right = df2.reset_index() if reset_index else df2

    if suffixes is not None:
        left = left.rename(columns={key: f"{key}{suffixes[0]}" for key in left.columns})
        right = right.rename(columns={key: f"{key}{suffixes[1]}" for key in right.columns})

    # Same order as the key order of {**row1, **row2}; passing it explicitly keeps the
    # columns when the product is empty (DataFrame([]) would have no columns at all).
    columns = list(dict.fromkeys([*left.columns, *right.columns]))
    data = [
        {**row1, **row2}
        for row1, row2 in product(left.to_dict("records"), right.to_dict("records"))
    ]
    return DataFrame(data, columns=columns)

In [None]:
# Combine every nucleon matrix element with every DM matrix element; the product of
# both values is the full spin matrix element.
df = df_outer_product(nuc_df, dm_df, suffixes=["_nuc", "_dm"], reset_index=True).assign(
    val=lambda frame: frame["val_nuc"] * frame["val_dm"]
)
spin_df = (
    df.set_index(["ms_o_dm", "ms_i_dm", "s_o_nuc", "ms_o_nuc", "s_i_nuc", "ms_i_nuc"])
    .sort_index()
    .loc[:, ["val"]]
)
spin_df.reset_index().query("s_o_nuc == s_i_nuc == 1")

In [None]:
CG = Function("CG")
# Weight of a single spin matrix element in the rank projection: a Clebsch-Gordan
# coefficient (presumably <s_i_nuc ms_i_nuc; sigma m_sigma | s_o_nuc ms_o_nuc> —
# confirm against get_cg's argument order), an azimuthal phase and the ratio
# (2 sigma + 1) / (2 s_o_nuc + 1). `CG` stays an unevaluated placeholder until all
# quantum numbers are substituted; `op_rank_project` then swaps it for `get_cg`.
pwd_fact = (
    CG("s_i_nuc", "ms_i_nuc", "sigma", "m_sigma", "s_o_nuc", "ms_o_nuc")
    * S("exp(I*(m_sigma + ms_o_dm - ms_i_dm)*(Phi-phi/2))")
    * ((2 * S("sigma") + 1) / (2 * S("s_o_nuc") + 1))
)
pwd_fact

In [None]:
# Component-wise k_n -> q_n/2 + p_i_n - p_o_n; keys and values are strings that
# `subs` sympifies when `subs_all` applies them.
momentum_subs = {f"k{n}": f"q_{n}/2 + p_i{n} - p_o{n}" for n in range(1, 4)}
momentum_subs

In [None]:
def subs_all(expr):
    """Express ``expr`` in the variables used for the angular integration.

    The substitutions are applied one after another, in this order:
    ``momentum_subs``, ``SPHERICAL_BASE_SUBS``, ``ANGLE_BASE_SUBS`` and finally
    ``q_1 = q_2 = 0``. The result is rewritten in exponentials and expanded.
    """
    stages = (
        momentum_subs,
        SPHERICAL_BASE_SUBS,
        ANGLE_BASE_SUBS,
        {"q_1": 0, "q_2": 0},
    )
    for substitutions in stages:
        expr = expr.subs(substitutions)
    return expr.rewrite("exp").expand()

In [None]:
def op_rank_project(tmp):
    """Project a group of spin matrix elements onto nuclear operator ranks.

    For every row, the allowed ranks follow the triangle rule
    ``|s_o - s_i| <= sigma <= s_o + s_i`` with ``m_sigma = ms_o - ms_i``; ranks with
    ``|m_sigma| > sigma`` are skipped. Each row's ``val`` is weighted with
    ``pwd_fact`` (its ``CG`` placeholder evaluated via ``get_cg``) and summed per
    ``(sigma, m_sigma)``. Every sum is then passed through ``subs_all`` and
    integrated over ``Phi`` from 0 to 2*pi.

    Parameters
    ----------
    tmp : DataFrame
        Rows with columns ``s_o_nuc``, ``ms_o_nuc``, ``s_i_nuc``, ``ms_i_nuc`` and
        ``val``, plus the other quantum numbers ``pwd_fact`` refers to
        (``ms_o_dm``, ``ms_i_dm``).

    Returns
    -------
    Series
        Named ``val`` and indexed by ``(sigma, m_sigma)``.
    """
    data = {}
    for row in tmp.to_dict("records"):
        s_o, s_i = row["s_o_nuc"], row["s_i_nuc"]
        # Fixed by the nucleon projections, independent of the rank sigma.
        m_sigma = row["ms_o_nuc"] - row["ms_i_nuc"]
        for sigma in range(abs(s_o - s_i), s_o + s_i + 1):
            if abs(m_sigma) > sigma:
                continue
            weight = pwd_fact.subs(
                {**row, "sigma": sigma, "m_sigma": m_sigma}
            ).replace(CG, get_cg)
            key = (sigma, m_sigma)
            data[key] = data.get(key, S(0)) + row["val"] * weight

    for key, val in data.items():
        data[key] = integrate(subs_all(val), ("Phi", 0, "2*pi"))

    projected = Series(data, name="val")
    projected.index.names = ("sigma", "m_sigma")
    return projected


groups = ["ms_o_dm", "ms_i_dm", "s_o_nuc", "s_i_nuc"]
# Project each group explicitly instead of `groupby(...).agg(op_rank_project)`: `agg`
# first feeds single columns (Series) to the function, where `to_dict("records")`
# raises TypeError; older pandas silently retried on whole groups, recent versions
# propagate the error.
records = [
    {**dict(zip(groups, key)), "sigma": sigma, "m_sigma": m_sigma, "val": val}
    for key, group in spin_df.reset_index().groupby(groups)
    for (sigma, m_sigma), val in op_rank_project(group).items()
]
res = DataFrame(records).set_index([*groups, "sigma", "m_sigma"])["val"]

In [None]:
index_cols = ["sigma", "m_sigma", "ms_o_dm", "ms_i_dm", "s_o_nuc", "s_i_nuc"]
# Keep only the non-vanishing projections and re-index them rank-first.
non_zero_res = (
    DataFrame(res.loc[res != 0])
    .reset_index()
    .set_index(index_cols)
    .sort_index()
)
non_zero_res.head()

In [None]:
# Building blocks the projected results are decomposed into (see `quotients`).
alpha = S("q_3 + 2 * p_i * x_i - 2 * p_o * x_o")
beta1 = S("exp(I*phi) * p_i * sqrt(1 - x_i**2) - p_o * sqrt(1 - x_o**2)")
beta2 = S("exp(-I*phi) * p_i * sqrt(1 - x_i**2) - p_o * sqrt(1 - x_o**2)")
# Algebraically omega == 4 * beta1 * beta2; it is written with cos(phi) and rewritten
# in exponentials, the same form `subs_all` produces.
omega = S(
    "4*p_i**2 *(1-x_i**2) + 4 * p_o**2 * (1-x_o**2)"
    " - 8*p_i * p_o * cos(phi) * sqrt(1-x_i**2)*sqrt(1-x_o**2)"
).rewrite("exp")
alpha, beta1, beta2

In [None]:
# Reference structures for `decompose`, tried in this (insertion) order: the label
# symbol on the left stands for the expression on the right.
quotients = {
    S(label): structure
    for label, structure in (
        ("a**2", alpha ** 2),
        ("a*b_1", alpha * beta1),
        ("a*b_2", alpha * beta2),
        ("b_1**2", beta1 ** 2),
        ("b_2**2", beta2 ** 2),
        ("e", omega),
    )
}

In [None]:
def decompose(ee, candidates=None):
    """Split ``ee`` into an exact factor times one reference structure.

    Candidates are tried in insertion order; the first one that divides ``ee``
    exactly (``sympy.exquo``) is used.

    Parameters
    ----------
    ee : Expr
        Expression to decompose.
    candidates : dict, optional
        Maps a label (e.g. ``a**2``) to the expression it stands for. Defaults to
        the notebook-level ``quotients``.

    Returns
    -------
    Series
        Entries ``fact`` (the exact quotient) and ``mat`` (the matching label); both
        are ``None`` when no candidate divides ``ee``.
    """
    if candidates is None:
        candidates = quotients
    for label, structure in candidates.items():
        try:
            fact = exquo(ee, structure)
        except (ExactQuotientFailed, ComputationFailed):
            # Not an exact divisor (or not polynomial in common generators): try next.
            continue
        return Series([fact, label], index=["fact", "mat"])
    return Series([None, None], index=["fact", "mat"])


# Sanity check on a known multiple; expected: fact == 2, mat == a**2.
decompose(2 * alpha ** 2)

In [None]:
# One row per non-zero projection: exact factor `fact` and matched structure `mat`.
decomposition = non_zero_res["val"].apply(decompose)
decomposition

In [None]:
# Legacy reference results; their column names are mapped onto the quantum-number
# names used in this notebook and the unused columns are dropped.
legacy = (
    read_csv("input-op-32.csv")
    .rename(
        columns={
            "m_chi_p": "ms_o_dm",
            "m_chi": "ms_i_dm",
            "s_p": "s_o_nuc",
            "s": "s_i_nuc",
        }
    )
    .drop(columns=["O", "m_chi_x2", "m_chi_p_x2"])
)
legacy["ms_o_dm"] = legacy["ms_o_dm"].apply(S)
legacy["ms_i_dm"] = legacy["ms_i_dm"].apply(S)
legacy = (
    legacy.set_index(decomposition.index.names)
    .sort_index()
    # Parse every cell ("Sqrt" -> "sqrt" for sympy). Per-column `Series.map` replaces
    # `DataFrame.applymap`, deprecated since pandas 2.1, and works on all versions.
    .apply(lambda column: column.map(lambda el: S(el.replace("Sqrt", "sqrt"))))
)
legacy.head()

In [None]:
# The new decomposition has to match the legacy reference table.
assert_frame_equal(left=legacy, right=decomposition)