In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Set of the environment variables EXECUTE_NB / READTHEDOCS that are defined.
# Non-empty presumably means the notebook is executed by a documentation build
# (static web page) — NOTE(review): not referenced in the visible part of the
# notebook; confirm where it is used.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

````{margin}
```{spec} Polarization sensitivity
:id: TR-017
:status: WIP
:tags: physics
```
````

# Polarization sensitivity

<!-- cspell:ignore mmikhasenko msigma nanmax nanmean nanstd nansum Remco -->

:::{epigraph}

Mikhail Mikhasenko [@mmikhasenko](https://github.com/mmikhasenko), Remco de Boer [@redeboer](https://github.com/redeboer)

:::

In [None]:
%pip install -q ampform==0.14.0 qrules==0.9.7 sympy==1.10.1 tensorwaves[jax]==0.4.5

This report is an attempt to formulate the model described in [this document](https://www.overleaf.com/7229968911cjshysdbfjtj) (login required) on polarization sensitivity in $\Lambda_c \to p\pi K$ with [SymPy](https://docs.sympy.org) and [TensorWaves](https://tensorwaves.rtfd.io).

In [None]:
from __future__ import annotations

import itertools
import logging
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
from ampform.sympy import (
    PoolSum,
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
    make_commutative,
)
from attrs import frozen
from IPython.display import HTML, Image, Math, display
from ipywidgets import HBox, HTMLMath, Tab, VBox, interactive_output
from matplotlib import cm
from matplotlib.colors import LogNorm
from qrules.particle import Particle
from symplot import create_slider
from sympy.core.symbol import Str
from sympy.physics.matrices import msigma
from sympy.physics.quantum.spin import Rotation as Wigner
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import (
    create_function,
    create_parametrized_function,
)

# Root logger: only show messages of level ERROR and above in this notebook.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)

# Particle database from qrules; particles are looked up by name below.
PDG = qrules.load_pdg()


def display_definitions(definitions: dict[sp.Symbol, sp.Expr]) -> None:
    """Render a mapping of symbols to expressions as an aligned LaTeX array.

    Keys and values are passed through :func:`sympy.sympify` first, so plain
    Python numbers (such as coupling values) can be displayed as well.
    """
    rows = []
    for symbol, expr in definitions.items():
        lhs = sp.latex(sp.sympify(symbol))
        rhs = sp.latex(sp.sympify(expr))
        rows.append(Rf"  {lhs} & = & {rhs} \\" + "\n")
    header = R"\begin{array}{rcl}" + "\n"
    footer = R"\end{array}"
    display(Math(header + "".join(rows) + footer))


def display_doit(
    expr: UnevaluatedExpression, deep=False, terms_per_line: int = 10
) -> None:
    """Show the definition ``expr = expr.doit()`` as a multi-line equation.

    Args:
        expr: Unevaluated expression whose definition should be displayed.
        deep: Forwarded to ``doit``; with ``False`` only the outermost
            expression is unfolded.
        terms_per_line: Number of terms per line in the rendered equation.
    """
    unfolded = expr.doit(deep=deep)
    rendered = sp.multiline_latex(
        lhs=expr,
        rhs=unfolded,
        terms_per_line=terms_per_line,
        environment="eqnarray",
    )
    display(Math(rendered))


# hack for moving Indexed indices below superscript of the base
def _print_Indexed_latex(self, printer, *args):
    base = printer._print(self.base)
    indices = ", ".join(map(printer._print, self.indices))
    return f"{base}_{{{indices}}}"


# Monkey-patch SymPy: all Indexed objects (e.g. A^K[ν, λ] and the helicity
# couplings below) are now rendered with the printer defined above.
sp.Indexed._latex = _print_Indexed_latex

## Amplitude model

Naming convention: $\Lambda_c(\mathbf{0}) \to p(\mathbf{1}) \pi(\mathbf{2}) K(\mathbf{3})$
- **Chain 1**: $K^{**} \to \pi K(23)$
- **Chain 2**: $\Lambda^{**} \to pK(31)$
- **Chain 3**: $\Delta^{**} \to p\pi(12)$

'Chain 0' is the sum of the three chains.

In [None]:
# Initial state and final-state particles of Λc → p(1) π(2) K(3).
Λc = PDG["Lambda(c)+"]
p = PDG["p"]
K = PDG["K-"]
π = PDG["pi+"]
# Two-body decay products of the resonance in each chain.
decay_products = {
    1: (π, K),
    2: (p, K),
    3: (p, π),
}
# Spectator particle produced together with the resonance: Λc → R + sibling.
siblings = {
    1: p,
    2: π,
    3: K,
}
# LaTeX labels of the three decay chains.
chain_ids = {
    1: "K^{**}",
    2: R"\Lambda^{**}",
    3: R"\Delta^{**}",
}

Resonance choices and their $LS$-couplings are as follows:

In [None]:
# PDG names of the resonances included in each decay chain.
resonance_names = {
    1: ["K*(892)0"],
    2: ["Lambda(1520)", "Lambda(1670)"],
    3: ["Delta(1232)++"],
}
# Same structure, with the names resolved to qrules Particle objects.
resonances = {
    chain_id: [PDG[name] for name in names]
    for chain_id, names in resonance_names.items()
}

In [None]:
@frozen
class Resonance:
    """A resonance together with its lowest allowed orbital angular momenta.

    Instances are immutable (attrs ``frozen``).
    """

    particle: Particle
    # lowest L of the resonance decay R → decay_products[chain_id]
    l_R: int
    # lowest L of the production Λc → R + siblings[chain_id]
    # (computed with strong=False, i.e. without a parity requirement)
    l_Λc: int

    @staticmethod
    def generate_ls(particle: Particle, chain_id: int) -> Resonance:
        """Create a Resonance with the minimal production and decay L."""
        # In this body, `generate_ls` resolves to the module-level function,
        # not to this static method (class scope is not an enclosing scope).
        production_couplings = generate_ls(
            Λc, particle, siblings[chain_id], strong=False
        )
        decay_couplings = generate_ls(particle, *decay_products[chain_id])
        lowest_production_L = min([L for L, _ in production_couplings])
        lowest_decay_L = min([L for L, _ in decay_couplings])
        return Resonance(
            particle, l_R=lowest_decay_L, l_Λc=lowest_production_L
        )


def generate_ls(
    parent: Particle,
    child1: Particle,
    child2: Particle,
    strong: bool = True,
    max_L: int = 3,
):
    """List the (L, S) couplings allowed for ``parent → child1 + child2``.

    S runs over |s1 - s2|, ..., s1 + s2 and L over 0, ..., max_L (both ranges
    inclusive, see ``arange``). A pair is kept if it satisfies the triangle
    rule |L - S| <= J_parent <= L + S and, when ``strong`` is true, parity
    conservation η_parent = η_1 η_2 (-1)^L.

    Args:
        parent: Decaying particle.
        child1: First decay product.
        child2: Second decay product.
        strong: Apply the parity-conservation requirement.
        max_L: Highest orbital angular momentum considered.

    Returns:
        Sorted list of ``(L, S)`` tuples of SymPy Rationals.
    """
    if strong:
        # Parities do not depend on L or S: look them up once, and only when
        # the parity requirement is actually applied.
        η0 = int(parent.parity)
        η1 = int(child1.parity)
        η2 = int(child2.parity)
    s1 = child1.spin
    s2 = child2.spin
    LS_values = set()
    for S in arange(abs(s1 - s2), s1 + s2):
        for L in arange(0, max_L):
            if not abs(L - S) <= parent.spin <= L + S:
                continue
            if strong and η0 != η1 * η2 * (-1) ** L:
                continue
            LS_values.add((L, S))
    return sorted(LS_values)


def arange(x1, x2):
    """Inclusive range from x1 to x2 in unit steps, as SymPy Rationals.

    Works for half-integer bounds, e.g. ``arange(-1.5, 1.5)`` gives
    ``[-3/2, -1/2, 1/2, 3/2]``; the ``+ 0.5`` on the stop value makes ``x2``
    itself part of the range.
    """
    stop = float(x2) + 0.5
    values = np.arange(float(x1), stop)
    return [sp.Rational(value) for value in values]


# Per chain: Resonance objects holding the lowest production and decay L.
resonance_choices = {
    chain_id: [
        Resonance.generate_ls(particle, chain_id) for particle in particles
    ]
    for chain_id, particles in resonances.items()
}


def jp(particle: Particle):
    """Return the spin-parity label of a particle as inline MathJax markup.

    The sign is "+" for positive parity and "-" otherwise.
    """
    parity_sign = "+" if particle.parity > 0 else "-"
    spin = sp.Rational(particle.spin)
    label = Rf"\({spin}^{parity_sign}\)"
    return label


def create_html_table_row(*items, typ="td"):
    """Wrap each item in a ``<typ>`` cell and join them into one table row.

    Items are inserted with ``str`` formatting and are not HTML-escaped.
    """
    cells = "".join(f"<{typ}>{item}</{typ}>" for item in items)
    return f"<tr>{cells}</tr>\n"


# HTML overview of the selected resonances: decay, J^P, mass and width in MeV,
# and the lowest decay / production orbital angular momenta.
column_names = [
    "resonance",
    R"\(j^P\)",
    R"\(m\) (MeV)",
    R"\(\Gamma_0\) (MeV)",
    R"\(l_R\)",
    R"\(l_{\Lambda_c}^\mathrm{min}\)",
]
src = "<table>\n"
src += create_html_table_row(*column_names, typ="th")
for chain_id, resonance_list in resonance_choices.items():
    child1, child2 = decay_products[chain_id]
    for resonance in resonance_list:
        src += create_html_table_row(
            Rf"\({resonance.particle.latex} \to"
            Rf" {child1.latex} {child2.latex}\)",
            jp(resonance.particle),
            # GeV → MeV: round to the nearest MeV instead of truncating
            round(1e3 * resonance.particle.mass),
            round(1e3 * resonance.particle.width),
            resonance.l_R,
            resonance.l_Λc,
        )
src += "</table>\n"
HTML(src)

### Aligned amplitude

In [None]:
# Chain amplitudes A^X[ν', λ'] indexed by the Λc and proton helicities.
A_K = sp.IndexedBase(R"A^K")
A_Λ = sp.IndexedBase(R"A^{\Lambda}")
A_Δ = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half

# Wigner rotation angles that align chain k to the reference chain 1:
# superscript 0 refers to the Λc, superscript 1 to the proton.
ζ_0_11 = sp.Symbol(R"\zeta^0_{1(1)}", real=True)
ζ_0_21 = sp.Symbol(R"\zeta^0_{2(1)}", real=True)
ζ_0_31 = sp.Symbol(R"\zeta^0_{3(1)}", real=True)
ζ_1_11 = sp.Symbol(R"\zeta^1_{1(1)}", real=True)
ζ_1_21 = sp.Symbol(R"\zeta^1_{2(1)}", real=True)
ζ_1_31 = sp.Symbol(R"\zeta^1_{3(1)}", real=True)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    """Sum of the three chain amplitudes, rotated to the chain-1 frame.

    Each chain amplitude A^X[ν', λ'] is multiplied by the Wigner d-functions
    d^{1/2}_{λ_Λc, ν'}(ζ^0_{k(1)}) and d^{1/2}_{λ', λ_p}(ζ^1_{k(1)}), and the
    result is summed over the dummy helicities ν', λ' ∈ {-1/2, +1/2}.
    """
    ν_dummy = sp.Symbol(R"\nu^{\prime}", rational=True)
    λ_dummy = sp.Symbol(R"\lambda^{\prime}", rational=True)
    chains = [
        (A_K, ζ_0_11, ζ_1_11),
        (A_Λ, ζ_0_21, ζ_1_21),
        (A_Δ, ζ_0_31, ζ_1_31),
    ]
    summand = sp.Add(
        *[
            amplitude[ν_dummy, λ_dummy]
            * Wigner.d(half, λ_Λc, ν_dummy, ζ_Λc)
            * Wigner.d(half, λ_dummy, λ_p, ζ_p)
            for amplitude, ζ_Λc, ζ_p in chains
        ]
    )
    return PoolSum(
        summand,
        (λ_dummy, [-half, +half]),
        (ν_dummy, [-half, +half]),
    )


# ν: Λc helicity, λ: proton helicity (used as free symbols throughout).
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)

### Dynamics

In [None]:
@make_commutative
@implement_doit_method
class BlattWeisskopf(UnevaluatedExpression):
    """Blatt–Weisskopf form factor F_L(z), defined for L = 0, 1, 2.

    ``evaluate`` returns a Piecewise of square roots of the squared form
    factors; only the three cases L = 0, 1, 2 are covered.
    """

    def __new__(cls, z, L, **hints):
        return create_expression(cls, z, L, **hints)

    def evaluate(self):
        z, L = self.args
        # squared form factors, keyed by orbital angular momentum
        squared_factors = (
            (0, 1),
            (1, 1 / (1 + z**2)),
            (2, 1 / (9 + 3 * z**2 + z**4)),
        )
        branches = [
            (sp.sqrt(value), sp.Eq(L, order))
            for order, value in squared_factors
        ]
        return sp.Piecewise(*branches)

    def _latex(self, printer, *args):
        z_latex, l_latex = (printer._print(arg) for arg in self.args)
        return Rf"F_{{{l_latex}}}\left({z_latex}\right)"


# Dummy symbols, only used to display the definition of the form factor.
z = sp.Symbol("z", positive=True)
L = sp.Symbol("L", integer=True, nonnegative=True)
latex = sp.multiline_latex(BlattWeisskopf(z, L), BlattWeisskopf(z, L).doit())
Math(latex)

In [None]:
@make_commutative
@implement_doit_method
class Källén(UnevaluatedExpression):
    """Källén (triangle) function λ(x, y, z)."""

    def __new__(cls, x, y, z, **hints):
        return create_expression(cls, x, y, z, **hints)

    def evaluate(self) -> sp.Expr:
        a, b, c = self.args
        return a**2 + b**2 + c**2 - 2 * a * b - 2 * b * c - 2 * c * a

    def _latex(self, printer, *args):
        a, b, c = (printer._print(arg) for arg in self.args)
        return Rf"\lambda\left({a}, {b}, {c}\right)"


# Dummy symbols, only used to display the definition of the Källén function.
x, y, z = sp.symbols("x:z")
display_doit(Källén(x, y, z))

In [None]:
@make_commutative
@implement_doit_method
class P(UnevaluatedExpression):
    """Two-body break-up momentum p = sqrt(λ(s, m_i², m_j²)) / (2 sqrt(s)).

    Only the first argument (``s``) is shown in the LaTeX representation.
    """

    def __new__(cls, s, mi, mj, **hints):
        return create_expression(cls, s, mi, mj, **hints)

    def evaluate(self):
        s, mi, mj = self.args
        numerator = sp.sqrt(Källén(s, mi**2, mj**2))
        return numerator / (2 * sp.sqrt(s))

    def _latex(self, printer, *args):
        s, *_ = self.args
        return Rf"p_{{{printer._print(s)}}}"


@make_commutative
@implement_doit_method
class Q(UnevaluatedExpression):
    """Momentum q = sqrt(λ(s, m0², mk²)) / (2 m0).

    In contrast to ``P``, the denominator contains the mass ``m0`` and not
    sqrt(s). Only ``s`` is shown in the LaTeX representation.
    """

    def __new__(cls, s, m0, mk, **hints):
        return create_expression(cls, s, m0, mk, **hints)

    def evaluate(self):
        s, m0, mk = self.args
        numerator = sp.sqrt(Källén(s, m0**2, mk**2))
        return numerator / (2 * m0)  # <-- not s!

    def _latex(self, printer, *args):
        s, *_ = self.args
        return Rf"q_{{{printer._print(s)}}}"


# Dummy symbols, only used to display the momentum definitions.
s, m0, mi, mj, mk = sp.symbols("s m0 m_i:k", nonnegative=True)
display_doit(P(s, mi, mj))
display_doit(Q(s, m0, mk))

In [None]:
# Meson radius in the Blatt–Weisskopf form factors. parameter_defaults collects
# the default values of all model parameters; resonance masses/widths and
# couplings are added to it further below.
R = sp.Symbol("R")
parameter_defaults = {
    R: 5,  # GeV^{-1} (length factor)
}


@make_commutative
@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
    """Energy-dependent width Γ(s) of a resonance decaying to masses m1, m2.

    Γ(s) = Γ0 (p/p0)^(2L+1) (m0/sqrt(s)) F_L(pR)² / F_L(p0 R)², where p is
    the break-up momentum at s and p0 the one at s = m0².
    """

    def __new__(cls, s, m0, Γ0, m1, m2, L, R, **hints):
        # Accept and forward hints like the other expression classes.
        return create_expression(cls, s, m0, Γ0, m1, m2, L, R, **hints)

    def evaluate(self):
        s, m0, Γ0, m1, m2, L, R = self.args
        p = P(s, m1, m2)
        p0 = P(m0**2, m1, m2)
        ff = BlattWeisskopf(p * R, L) ** 2
        ff0 = BlattWeisskopf(p0 * R, L) ** 2
        # evaluate=False keeps the factors in this order for display
        return sp.Mul(
            Γ0,
            (p / p0) ** (2 * L + 1),
            m0 / sp.sqrt(s),
            ff / ff0,
            evaluate=False,
        )

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"\Gamma\left({s}\right)"


# Orbital angular momentum can be zero (e.g. for the Λ(1670) decay), so the
# symbol must be nonnegative, not positive: with positive=True SymPy evaluates
# Eq(l_R, 0) to False and drops the L = 0 branch of the form factor.
l_R = sp.Symbol("l_R", integer=True, nonnegative=True)
m, Γ0, m1, m2 = sp.symbols("m Γ0 m1 m2", nonnegative=True)
display_doit(EnergyDependentWidth(s, m, Γ0, m1, m2, l_R, R))

In [None]:
@make_commutative
@implement_doit_method
class RelativisticBreitWigner(UnevaluatedExpression):
    """Relativistic Breit–Wigner with production and decay barrier factors.

    R(s) = (q/q0)^l_Λc F(qR)/F(q0R) · 1/(m0² - s - i m0 Γ(s))
           · (p/p0)^l_R F(pR)/F(p0R),
    with Γ(s) the energy-dependent width.
    """

    def __new__(cls, s, m0, Γ0, m1, m2, l_R, l_Λc, R, **hints):
        # Accept and forward hints like the other expression classes.
        return create_expression(cls, s, m0, Γ0, m1, m2, l_R, l_Λc, R, **hints)

    def evaluate(self):
        s, m0, Γ0, m1, m2, l_R, l_Λc, R = self.args
        # NOTE(review): Q's argument names (s, m0, mk) suggest (s, parent mass,
        # spectator mass), but here it receives the resonance's decay-product
        # masses m1, m2, so q is proportional to the decay momentum p rather
        # than the Λc → R + spectator break-up momentum. A fix needs m_Λc and
        # the spectator mass as extra arguments — confirm against the report.
        q = Q(s, m1, m2)
        q0 = Q(m0**2, m1, m2)
        p = P(s, m1, m2)
        p0 = P(m0**2, m1, m2)
        width = EnergyDependentWidth(s, m0, Γ0, m1, m2, l_R, R)
        # evaluate=False keeps the factors in this order for display
        return sp.Mul(
            (q / q0) ** l_Λc,
            BlattWeisskopf(q * R, l_Λc) / BlattWeisskopf(q0 * R, l_Λc),
            1 / (m0**2 - s - sp.I * m0 * width),
            (p / p0) ** l_R,
            BlattWeisskopf(p * R, l_R) / BlattWeisskopf(p0 * R, l_R),
            evaluate=False,
        )

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"\mathcal{{R}}\left({s}\right)"


# The production orbital angular momentum can be zero as well, so declare the
# symbol nonnegative (positive=True would rule out l_Λc = 0).
l_Λc = sp.Symbol(R"l_{\Lambda_c}", integer=True, nonnegative=True)
display_doit(RelativisticBreitWigner(s, m, Γ0, m1, m2, l_R, l_Λc, R))

### Decay chain amplitudes

In [None]:
def formulate_chain_amplitude(chain_id: int, λ_Λc, λ_p):
    """Amplitude of one decay chain, summed over the resonances of that chain.

    Args:
        chain_id: 1 (K**), 2 (Λ**) or 3 (Δ**).
        λ_Λc: Helicity of the Λc.
        λ_p: Helicity of the proton.

    Raises:
        KeyError: if ``chain_id`` is not a key of ``resonance_choices``.
        NotImplementedError: if there is no amplitude builder for the chain.
    """
    resonances = resonance_choices[chain_id]
    # Looked up at call time: the builders are defined further below.
    builders = {
        1: formulate_K_amplitude,
        2: formulate_Λ_amplitude,
        3: formulate_Δ_amplitude,
    }
    builder = builders.get(chain_id)
    if builder is None:
        raise NotImplementedError
    return builder(λ_Λc, λ_p, resonances)


# Helicity couplings, indexed as [resonance, helicity, helicity].
H_prod = sp.IndexedBase(R"\mathcal{H}^\mathrm{production}")
H_dec = sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}")

# Helicity angles of the three subsystems (23 = πK, 31 = Kp, 12 = pπ).
θ23 = sp.Symbol("theta23", real=True)
θ31 = sp.Symbol("theta31", real=True)
θ12 = sp.Symbol("theta12", real=True)

# Mandelstam variables σ1 = m²(πK), σ2 = m²(pK), σ3 = m²(pπ) and the
# final-state masses. m1, m2 are re-bound here to the proton and pion masses.
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
m1, m2, m3 = sp.symbols(R"m_p m_pi m_K", nonnegative=True)


def formulate_K_amplitude(λ_Λc, λ_p, resonances: list[Resonance]):
    """Chain-1 amplitude: Λc → K** p with K** → π K.

    For every resonance, the spin projection τ of the K** is summed over. The
    Kronecker delta enforces λ_Λc = τ - λ_p. The dynamics are evaluated at σ1
    with the π (m2) and K (m3) masses, and the Wigner d-function uses θ23.
    Calling this registers the resonance masses and widths in
    ``parameter_defaults`` (via ``formulate_dynamics``).
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ - λ_p)
                * H_prod[stringify(res), τ, λ_p]
                * formulate_dynamics(res, σ1, m2, m3)
                * (-1) ** (half - λ_p)
                * Wigner.d(sp.Rational(res.particle.spin), τ, 0, θ23)
                * H_dec[stringify(res), 0, 0],
                (τ, create_spin_range(res.particle.spin)),
            )
            for res in resonances
        ]
    )


def formulate_Λ_amplitude(λ_Λc, λ_p, resonances: list[Resonance]):
    """Chain-2 amplitude: Λc → Λ** π with Λ** → p K.

    The spin projection τ of the Λ** is summed over. The Kronecker delta sets
    τ = λ_Λc, and the spectator index of H_prod is 0. The dynamics are
    evaluated at σ2 with the proton (m1) and K (m3) masses. The Wigner
    d-function uses θ31 and -λ_p, with an extra phase (-1)^(1/2 - λ_p).
    Calling this registers the resonance masses and widths in
    ``parameter_defaults`` (via ``formulate_dynamics``).
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(res), τ, 0]
                * formulate_dynamics(res, σ2, m1, m3)
                * Wigner.d(sp.Rational(res.particle.spin), τ, -λ_p, θ31)
                * H_dec[stringify(res), 0, λ_p]
                * (-1) ** (half - λ_p),
                (τ, create_spin_range(res.particle.spin)),
            )
            for res in resonances
        ]
    )


def formulate_Δ_amplitude(λ_Λc, λ_p, resonances: list[Resonance]):
    """Chain-3 amplitude: Λc → Δ** K with Δ** → p π.

    The spin projection τ of the Δ** is summed over, and the Kronecker delta
    sets τ = λ_Λc. The dynamics are evaluated at σ3 with the proton (m1) and
    π (m2) masses, and the Wigner d-function uses θ12. Calling this registers
    the resonance masses and widths in ``parameter_defaults`` (via
    ``formulate_dynamics``).
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(res), τ, 0]
                * formulate_dynamics(res, σ3, m1, m2)
                * Wigner.d(sp.Rational(res.particle.spin), τ, λ_p, θ12)
                * H_dec[stringify(res), λ_p, 0],
                (τ, create_spin_range(res.particle.spin)),
            )
            for res in resonances
        ]
    )


def formulate_dynamics(decay: Resonance, s, m1, m2):
    """Relativistic Breit–Wigner for one resonance at invariant mass squared s.

    The resonance mass and width become symbols named after the particle's
    LaTeX label. Their PDG values are stored in the global
    ``parameter_defaults`` (side effect).
    """
    l_decay = sp.Rational(decay.l_R)
    l_production = sp.Rational(decay.l_Λc)
    particle = decay.particle
    mass = sp.Symbol(f"m_{{{particle.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
    parameter_defaults.update({mass: particle.mass, width: particle.width})
    return RelativisticBreitWigner(
        s, mass, width, m1, m2, l_decay, l_production, R
    )


def stringify(particle: Particle | Resonance) -> Str:
    """SymPy ``Str`` of the particle's LaTeX label, used as a coupling index."""
    if not isinstance(particle, Resonance):
        return Str(particle.latex)
    return Str(particle.particle.latex)


def create_spin_range(j):
    """Spin projections -j, -j + 1, ..., +j as SymPy Rationals."""
    lowest, highest = -j, +j
    return arange(lowest, highest)


# Show the symbolic amplitude of each of the three decay chains.
display(
    formulate_chain_amplitude(1, ν, λ),
    formulate_chain_amplitude(2, ν, λ),
    formulate_chain_amplitude(3, ν, λ),
)

### Angle definitions

The following relations apply:

$$
\begin{eqnarray}
  \zeta^0_{1(1)} &=& \hat{\theta}_{1(1)}^{0} = 0 \\
  \zeta^0_{2(1)} &=& \hat{\theta}_{2(1)} = -\hat{\theta}_{1(2)} \\
  \zeta^0_{3(1)} &=& \hat{\theta}_{3(1)} \\
  \zeta^1_{1(1)} &=& 0 \\
  \zeta^1_{3(1)} &=& -\zeta^1_{1(3)} \\
\end{eqnarray}
$$

The remaining angles $\theta_{12}, \theta_{23}, \theta_{31}$ and $\hat\theta_{1(2)}, \hat\theta_{3(1)}, \zeta^1_{1(3)}, \zeta^1_{2(1)}$ can be expressed in terms of the Mandelstam variables $\sigma_1, \sigma_2, \sigma_3$ using {cite}`mikhasenkoDalitzplotDecompositionThreebody2020`, Appendix A:

In [None]:
m0 = sp.Symbol(R"m_{\Lambda_c}", nonnegative=True)
# Angles as functions of the Mandelstam variables σ1, σ2, σ3 and the masses
# m0 (Λc), m1 (p), m2 (π), m3 (K). Each cosine is a ratio of a polynomial and
# square roots of Källén functions; acos only returns the magnitude of the
# angle, so signs are put in by hand where needed (ζ_0_21 and ζ_1_31).
angles = {
    θ12: sp.acos(  # helicity angle used in the Δ** chain (σ3 = m²(pπ))
        (
            2 * σ3 * (σ2 - m3**2 - m1**2)
            - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m3**2, σ3))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
    θ23: sp.acos(  # helicity angle used in the K** chain (σ1 = m²(πK))
        (
            2 * σ1 * (σ3 - m1**2 - m2**2)
            - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ1, m2**2, m3**2))
        )
    ),
    θ31: sp.acos(  # helicity angle used in the Λ** chain (σ2 = m²(pK))
        (
            2 * σ2 * (σ1 - m2**2 - m3**2)
            - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(σ2, m3**2, m1**2))
        )
    ),
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta_{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(m0**2, σ1, m1**2))
        )
    ),
    ζ_0_31: sp.acos(  # = \hat\theta_{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(m0**2, σ3, m3**2))
        )
    ),
    ζ_1_11: sp.S.Zero,  # chain 1 is the reference chain: no rotation
    ζ_1_21: sp.acos(  # proton Wigner angle, chain 2 → chain 1
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ2, m1**2, m3**2))
        )
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
}

display_definitions(angles)

where $m_0$ is the mass of the initial state $\Lambda_c$ and $m_1, m_2, m_3$ are the masses of $p, \pi, K$, respectively:

In [None]:
# Numerical values (GeV) for the initial- and final-state mass symbols.
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
display_definitions(masses)

### Helicity coupling values

In [None]:
def _parity_factor(resonance: Particle, child1: Particle, child2: Particle) -> int:
    """Sign relating the decay couplings with flipped helicities.

    Returns η_R η_1 η_2 (-1)^(J_R - j_1 - j_2) as an exact integer.

    Raises:
        ValueError: if J_R - j_1 - j_2 is not an integer.
    """
    exponent = resonance.spin - child1.spin - child2.spin
    if exponent != int(exponent):
        raise ValueError(
            f"Non-integer spin combination {exponent} for {resonance.name}"
        )
    # `% 2` keeps the power an int (a negative int exponent would give a float)
    return (
        int(resonance.parity)
        * int(child1.parity)
        * int(child2.parity)
        * (-1) ** (int(exponent) % 2)
    )


# Decay couplings: the reference helicity coupling is set to 1, and the
# coupling with flipped helicities follows from parity conservation.
dec_couplings = {}
for res in resonance_choices[1]:
    i = stringify(res)
    dec_couplings[H_dec[i, 0, 0]] = 1
for res in resonance_choices[2]:
    i = stringify(res.particle)
    dec_couplings[H_dec[i, 0, half]] = 1
    dec_couplings[H_dec[i, 0, -half]] = _parity_factor(res.particle, K, p)
for res in resonance_choices[3]:
    i = stringify(res.particle)
    dec_couplings[H_dec[i, half, 0]] = 1
    dec_couplings[H_dec[i, -half, 0]] = _parity_factor(res.particle, p, π)
parameter_defaults.update(dec_couplings)
display_definitions(dec_couplings)

In [None]:
# Production couplings H_prod[resonance, resonance helicity, spectator helicity].
# The resonance keys must equal stringify(resonance), i.e. its LaTeX label.
prod_couplings = {
    # chain 23 (K** → πK):
    H_prod[Str("K^{*}(892)^{0}"), 0, -half]: 1,
    H_prod[Str("K^{*}(892)^{0}"), -1, -half]: 1 - 1j,
    H_prod[Str("K^{*}(892)^{0}"), +1, +half]: -3 - 3j,
    H_prod[Str("K^{*}(892)^{0}"), 0, +half]: -1 - 4j,
    # chain 31 (Λ** → pK):
    H_prod[Str("\\Lambda(1520)"), +half, 0]: 1.5,
    H_prod[Str("\\Lambda(1520)"), -half, 0]: 0.3,
    H_prod[Str("\\Lambda(1670)"), +half, 0]: -0.5 + 1j,
    H_prod[Str("\\Lambda(1670)"), -half, 0]: -0.3 - 0.1j,
    # chain 12 (Δ** → pπ):
    H_prod[Str("\\Delta(1232)^{++}"), +half, 0]: -13 + 5j,
    H_prod[Str("\\Delta(1232)^{++}"), -half, 0]: -7 + 3j,
}
display_definitions(prod_couplings)
couplings = {**dec_couplings, **prod_couplings}
parameter_defaults.update(prod_couplings)

### Intensity expression

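The intensity is the incoherent sum, over the $\Lambda_c$ helicity $\nu$ and the proton helicity $\lambda$, of the squared amplitudes defined in {ref}`report/017:Aligned amplitude`: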

In [None]:
def formulate_intensity(amplitude_builder):
    """Incoherent sum of |A(ν, λ)|² over the Λc helicity ν and proton helicity λ.

    Args:
        amplitude_builder: Callable with signature ``(λ_Λc, λ_p)`` returning
            the amplitude, e.g. ``formulate_aligned_amplitude`` or a partial
            of ``formulate_chain_amplitude``.
    """
    # Pass (ν, λ) in the builder's (λ_Λc, λ_p) order, consistent with the rest
    # of the notebook. The sum is symmetric, so the value is unaffected.
    return PoolSum(
        sp.Abs(amplitude_builder(ν, λ)) ** 2,
        (λ, [-half, +half]),
        (ν, [-half, +half]),
    )


# 0: aligned total intensity; 1–3: intensity of each individual chain.
intensity_expressions = {
    0: formulate_intensity(formulate_aligned_amplitude),
    1: formulate_intensity(partial(formulate_chain_amplitude, 1)),
    2: formulate_intensity(partial(formulate_chain_amplitude, 2)),
    3: formulate_intensity(partial(formulate_chain_amplitude, 3)),
}
intensity_expressions[0]

The remaining {attr}`~sympy.core.basic.Basic.free_symbols` of the aligned intensity are the specific amplitudes defined in {ref}`report/017:Decay chain amplitudes`.

The specific amplitudes from {ref}`report/017:Decay chain amplitudes` need to be formulated for each value of $\nu, \lambda$, so that they can be substituted in the top expression:

In [None]:
# Map each amplitude symbol A^X[ν, λ] of the aligned amplitude to the chain
# amplitude evaluated at those helicity values.
A = {1: A_K, 2: A_Λ, 3: A_Δ}
amp_definitions = {}
for chain_id in chain_ids:
    # The symbolic chain amplitude does not depend on the helicity values, so
    # build it once per chain and only substitute the helicities in the loop.
    expr = formulate_chain_amplitude(chain_id, ν, λ)
    for Λc_heli, p_heli in itertools.product([-half, +half], [-half, +half]):
        symbol = A[chain_id][Λc_heli, p_heli]
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_definitions(amp_definitions)

In [None]:
# Fully substitute each intensity: unfold the sums, insert the chain
# amplitudes, then the angle definitions and masses. Parameters are kept
# symbolic in the stored expression; they are only substituted for the
# free-symbol check. NOTE(review): same loop as for the polarization
# expressions below — candidate for a shared helper.
substituted_intensity_expressions = {}
for chain_id, expr in intensity_expressions.items():
    expr = expr.doit().xreplace(amp_definitions).doit()
    expr = expr.xreplace(angles).doit().xreplace(masses)
    substituted_intensity_expressions[chain_id] = expr
    expr = expr.xreplace(parameter_defaults)
    if chain_id == 0:
        assert expr.free_symbols == {σ1, σ2, σ3}
    else:
        # a single chain depends on a strict subset of the Mandelstam variables
        assert expr.free_symbols < {σ1, σ2, σ3}

### Polarization

$$
\vec\alpha(m_{K\pi},m_{pK}) =  \sum_{\lambda,\nu,\nu'} A^{*}_{\nu,\lambda}\vec\sigma_{\nu,\nu'}  A_{\nu',\lambda} \,\big / \sum_{\lambda,\nu} \left|A_{\nu,\lambda}\right|^2
$$ (polarization-sensitivity)

#### Total polarization sensitivity

In [None]:
def to_index(helicity):
    """Symbolic conversion of half-value helicities to Pauli matrix indices.

    The returned Piecewise is 1 for a helicity <= 0 (e.g. -1/2) and 0
    otherwise (e.g. +1/2).
    """
    # https://github.com/ComPWA/compwa-org/pull/129#issuecomment-1096599896
    non_positive = sp.LessThan(helicity, 0)
    return sp.Piecewise((1, non_positive), (0, True))


# NOTE: same name as the dummy inside formulate_aligned_amplitude, but that one
# is created with rational=True, so SymPy treats them as different symbols.
ν_prime = sp.Symbol(R"\nu^{\prime}")
# α_i = Σ A*(ν, λ) σ_i[ν, ν'] A(ν', λ) / I for the Pauli matrices σ_1..σ_3.
total_polarization = sp.Array(
    PoolSum(
        formulate_aligned_amplitude(ν, λ).conjugate()
        * msigma(i)[to_index(ν), to_index(ν_prime)]
        * formulate_aligned_amplitude(ν_prime, λ),
        (λ, [-half, +half]),
        (ν, [-half, +half]),
        (ν_prime, [-half, +half]),
    )
    / intensity_expressions[0]
    for i in [1, 2, 3]
)

#### Polarization sensitivity per chain

In [None]:
# Same polarization vector, but computed from one decay chain only.
polarization_expressions = {0: total_polarization}
for chain_id in chain_ids:
    polarization_expressions[chain_id] = sp.Array(
        PoolSum(
            formulate_chain_amplitude(chain_id, ν, λ).conjugate()
            * msigma(i)[to_index(ν), to_index(ν_prime)]
            * formulate_chain_amplitude(chain_id, ν_prime, λ),
            (λ, [-half, +half]),
            (ν, [-half, +half]),
            (ν_prime, [-half, +half]),
        )
        / intensity_expressions[chain_id]
        for i in [1, 2, 3]
    )

In [None]:
# Same substitution steps and free-symbol checks as for the intensities,
# applied to the three-component polarization arrays.
substituted_polarization_expressions = {}
for chain_id, expr in polarization_expressions.items():
    expr = expr.doit().xreplace(amp_definitions).doit()
    expr = expr.xreplace(angles).doit().xreplace(masses)
    substituted_polarization_expressions[chain_id] = expr
    expr = expr.xreplace(parameter_defaults)
    if chain_id == 0:
        assert expr.free_symbols == {σ1, σ2, σ3}
    else:
        assert expr.free_symbols < {σ1, σ2, σ3}

## Computations with TensorWaves


### Conversion to computational backend

The full [expression tree](https://docs.sympy.org/latest/tutorial/manipulation.html) can be converted to a computational, _parametrized_ function as follows. Note that all coupling symbols, as well as the resonance masses and widths, are interpreted as parameters. The remaining symbols (the Mandelstam variables $\sigma_1, \sigma_2, \sigma_3$) become arguments to the function.

In [None]:
# Free parameters: resonance masses ("m_..."), widths ("\Gamma_...") and all
# helicity couplings. Everything else in parameter_defaults (here: R) is fixed
# and substituted into the expressions.
free_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol.name.startswith("m_")
    or symbol.name.startswith(R"\Gamma_")
    or symbol in couplings
}
fixed_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol not in free_parameters
}

In [None]:
# One parametrized JAX function per chain; its arguments are the Mandelstam
# variables that remain after substitution.
intensity_functions = {
    chain_id: create_parametrized_function(
        expr.subs(fixed_parameters),
        parameters=free_parameters,
        backend="jax",
    )
    for chain_id, expr in substituted_intensity_expressions.items()
}

In [None]:
# Per chain: a list of three functions, one per polarization component.
polarization_functions = {
    chain_id: [
        create_parametrized_function(
            expr[i].subs(fixed_parameters),
            parameters=free_parameters,
            backend="jax",
        )
        for i in range(3)
    ]
    for chain_id, expr in substituted_polarization_expressions.items()
}

### Phase space

In [None]:
# σ1 + σ2 + σ3 = m0² + m1² + m2² + m3², so σ3 follows from σ1 and σ2.
computed_σ3 = m0**2 + m1**2 + m2**2 + m3**2 - σ1 - σ2
compute_third_mandelstam = create_function(
    computed_σ3.subs(masses), backend="jax"
)
display_definitions({σ3: computed_σ3})

In [None]:
def kibble_function(σ1, σ2, σ3=σ3):
    """Kibble function λ(λ(σ2, m2², m0²), λ(σ3, m3², m0²), λ(σ1, m1², m0²)).

    The physical region is where this function is non-positive (see
    ``is_within_phsp``).

    Args:
        σ1: Mandelstam variable m²(πK).
        σ2: Mandelstam variable m²(pK).
        σ3: Mandelstam variable m²(pπ). Defaults to the global symbol σ3,
            which was previously picked up implicitly; pass an expression in
            σ1, σ2 to eliminate it directly.
    """
    return Källén(
        Källén(σ2, m2**2, m0**2),
        Källén(σ3, m3**2, m0**2),
        Källén(σ1, m1**2, m0**2),
    )


def is_within_phsp(σ1, σ2, non_phsp_value=sp.nan):
    """Indicator: 1 where the Kibble function is <= 0, else ``non_phsp_value``.

    With the default NaN, multiplying a distribution by this indicator masks
    the region outside phase space.
    """
    inside = sp.LessThan(kibble_function(σ1, σ2), 0)
    return sp.Piecewise((1, inside), (non_phsp_value, True))


is_within_phsp(σ1, σ2)

In [None]:
# The Kibble function contains the symbol σ3; replace it by its expression in
# σ1, σ2 so that the indicator only depends on σ1 and σ2.
in_phsp_expr = is_within_phsp(σ1, σ2).subs(σ3, computed_σ3).subs(masses).doit()
in_phsp_expr.free_symbols

In [None]:
# Regular resolution × resolution grid over the kinematic limits of σ1, σ2.
resolution = 200
m0_val, m1_val, m2_val, m3_val = masses.values()
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2
X, Y = np.meshgrid(
    np.linspace(σ1_min, σ1_max, num=resolution),
    np.linspace(σ2_min, σ2_max, num=resolution),
)
# Positional call of the raw backend function — presumably ordered (σ1, σ2);
# confirm against compute_third_mandelstam.argument_order.
Z = compute_third_mandelstam.function(X, Y)
σ_arrays = {"sigma1": X, "sigma2": Y, "sigma3": Z}

# 1 inside the Dalitz region, NaN outside
in_phsp = create_function(in_phsp_expr, backend="numpy")
phsp = in_phsp(σ_arrays)

Values for the angles will be computed from the Mandelstam values with a data transformer for the symbolic angle definitions:

In [None]:
# Angle definitions with masses and fixed parameters inserted.
kinematic_variables = {
    symbol: expression.doit().subs(masses).subs(fixed_parameters)
    for symbol, expression in angles.items()
}
kinematic_variables.update({s: s for s in [σ1, σ2, σ3]})  # include identity
# Maps the Mandelstam arrays to all kinematic variables (angles + σ's).
transformer = SympyDataTransformer.from_sympy(
    kinematic_variables, backend="jax"
)
kinematic_arrays = transformer(σ_arrays)

### Intensity distribution

Finally, all intensities can be computed as follows:

In [None]:
# Intensity on the σ1–σ2 grid for the total (0) and each chain (1–3).
intensities = {
    chain_id: func(kinematic_arrays)
    for chain_id, func in intensity_functions.items()
}

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
# Axis labels for the Dalitz plots.
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

fig, ax = plt.subplots(
    figsize=(10, 8),
    tight_layout=True,
)
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)

# Multiplying by `phsp` masks points outside phase space (NaN); log color scale.
mesh = ax.pcolormesh(
    X, Y, phsp * intensities[0], cmap=cm.coolwarm, norm=LogNorm()
)
fig.colorbar(mesh, ax=ax)
plt.show()

### Fit fractions

The total decay rate for $\Lambda_c^+ \to pK\pi$ can be broken into fractions that correspond to the different decay chains and interference terms. The total rate is computed as an integral of the intensity over decay kinematics:

$$
\begin{align}
  I_\text{tot}(\{\mathcal{H}\}) = \int d m_{pK}^2 d m_{K\pi}^2\,
  I_0(m_{pK}, m_{K\pi} | \{\mathcal{H}\})
  \approx \frac{\Phi_0}{N_\text{MC}} \sum_{e=1}^{N_\text{MC}}\,\,I_0(m_{pK,e}, m_{K\pi,e} | \{\mathcal{H}\})\,,
\end{align}
$$

where $\Phi_0$ is an (irrelevant) constant equal to the flat phase-space integral and $(m_{pK,e}, m_{K\pi,e})$ is the vector of kinematic variables for the $e$-th point in the MC sample.

The conditional argument $\{\mathcal{H}\}$ indicates that the rate depends on the values of the couplings. The individual fractions are found by computing the total rate with a subset of the couplings set to zero,

$$
\begin{align}
  I_\text{tot}^{K} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to\Delta^{**} K}, \mathcal{H}^{\Lambda_c^+\to\Lambda^{**} \pi} = 0\}\right)\,,\\
  I_\text{tot}^{\Delta} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to K^{**} p}, \mathcal{H}^{\Lambda_c^+\to\Lambda^{**} \pi} = 0\}\right)\,,\\
  I_\text{tot}^{\Lambda} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to\Delta^{**} K}, \mathcal{H}^{\Lambda_c^+\to K^{**} p} = 0\}\right)\,,\\
  I_\text{tot}^{K/\Lambda} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to\Delta^{**} K} = 0\}\right) -  I_\text{tot}^{K} - I_\text{tot}^{\Lambda}\,,\\
  & \dots\,,
\end{align}
$$

where the terms with a single chain index are the rates of the individual decay chains and the terms with two chain indices are the interference terms between two chains. The sum of all fractions should give the total rate:

$$
\begin{align}
  I_\text{tot}\left(\{\mathcal{H}\}\right)
  = \sum_{R} I_\text{tot}^{R} +  \sum_{R < R'} I_\text{tot}^{R/R'}
\end{align}
$$

In [None]:
def sub_intensity(data, non_zero_couplings: list[str]):
    """Integrated total intensity with only a subset of production couplings.

    Every production coupling whose parameter name does not contain one of the
    substrings in ``non_zero_couplings`` is temporarily set to zero. The
    parameters of the shared ``intensity_functions[0]`` are restored
    afterwards, also when the evaluation raises.

    Args:
        data: Kinematic arrays passed to the intensity function.
        non_zero_couplings: Substrings (e.g. ``"K"``, ``"Lambda"``) that select
            the production couplings to keep.

    Returns:
        The integrated intensity (see ``integrate_intensity``).
    """
    func = intensity_functions[0]
    old_parameters = dict(func.parameters)
    new_parameters = dict(old_parameters)
    for par_name in new_parameters:
        if not par_name.startswith(R"\mathcal{H}^\mathrm{production}"):
            continue
        if any(s in par_name for s in non_zero_couplings):
            continue
        new_parameters[par_name] = 0
    func.update_parameters(new_parameters)
    try:
        sub_intensities = func(data)
    finally:
        # Never leave the shared function with zeroed couplings behind.
        func.update_parameters(old_parameters)
    return integrate_intensity(sub_intensities)


def integrate_intensity(intensities):
    """Monte-Carlo estimate of the intensity integral, up to the constant Φ0.

    NaN entries (presumably points outside phase space — confirm) contribute
    zero to the sum but are still counted as samples. The denominator is the
    total number of samples, so the estimate is correct for multi-dimensional
    grids as well; ``len`` would only count the first axis.
    """
    return np.nansum(intensities) / np.size(intensities)


# Sanity check: keeping the couplings of all three chains reproduces the total.
I_tot = integrate_intensity(intensity_functions[0](kinematic_arrays))
np.testing.assert_allclose(
    I_tot,
    sub_intensity(kinematic_arrays, ["K", R"\Lambda", R"\Delta"]),
)

In [None]:
def interference_intensity(
    data,
    chain1: list[str],
    chain2: list[str],
):
    """Interference term between two chains: I(1 + 2) - I(1) - I(2).

    Each argument is a list of coupling-name substrings as accepted by
    ``sub_intensity``.
    """
    combined = sub_intensity(data, chain1 + chain2)
    only_first = sub_intensity(data, chain1)
    only_second = sub_intensity(data, chain2)
    return combined - only_first - only_second


# Single-chain rates and pairwise interference terms.
I_K = sub_intensity(kinematic_arrays, non_zero_couplings=["K"])
I_Λ = sub_intensity(kinematic_arrays, non_zero_couplings=["Lambda"])
I_Δ = sub_intensity(kinematic_arrays, non_zero_couplings=["Delta"])
I_ΛΔ = interference_intensity(kinematic_arrays, ["Lambda"], ["Delta"])
I_KΔ = interference_intensity(kinematic_arrays, ["K"], ["Delta"])
I_KΛ = interference_intensity(kinematic_arrays, ["K"], ["Lambda"])
# The fractions must add up to the total rate.
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)

In [None]:
# Fractions relative to the total rate, rendered as a LaTeX table.
rows = [
    ("K^{**}", f"{I_K/I_tot:.3f}"),
    (R"\Lambda^{**}", f"{I_Λ/I_tot:.3f}"),
    (R"\Delta^{**}", f"{I_Δ/I_tot:.3f}"),
    (R"\Delta/\Lambda", f"{I_ΛΔ/I_tot:.3f}"),
    (R"K/\Delta", f"{I_KΔ/I_tot:.3f}"),
    (R"K/\Lambda", f"{I_KΛ/I_tot:.3f}"),
    (
        R"\mathrm{total}",
        f"{(I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ) /I_tot:.3f}",
    ),
]

latex = R"\begin{array}{crr}" + "\n"
latex += R"& I_\mathrm{sub}\,/\,I \\" + "\n"
for row in rows:
    latex += "  " + " & ".join(row) + R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

### Polarization distributions

In [None]:
# Polarization components (x, y, z) on the grid, per chain (0 = total).
polarization_values = {
    chain_id: [funcs[i](kinematic_arrays) for i in range(3)]
    for chain_id, funcs in polarization_functions.items()
}
# The polarization components are expected to be real. Check the *magnitude*
# of the imaginary part; the signed maximum would let arbitrarily large
# negative imaginary parts pass.
for chain_id, components in polarization_values.items():
    for array in components:
        assert np.nanmax(np.abs(array.imag)) < 1e-10

{{ run_interactive }}

In [None]:
def render_mean(array):
    """Format ``mean ± std`` (NaN-aware) as LaTeX with three decimals.

    A "+" is prepended when the rounded mean is strictly positive.
    """
    mean_text = format(np.nanmean(array), ".3f")
    std_text = format(np.nanstd(array), ".3f")
    sign = "+" if float(mean_text) > 0 else ""
    return sign + mean_text + R" \pm " + std_text


# Mean ± standard deviation of each polarization component over the grid.
# NOTE(review): chain_ids only contains chains 1–3, so the total (chain 0) is
# not listed — intentional?
latex = R"\begin{array}{cccc}" + "\n"
latex += R"& \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\" + "\n"
for chain_id, label in chain_ids.items():
    latex += f"  {label} & "
    x, y, z = polarization_values[chain_id]
    latex += " & ".join(map(render_mean, [x.real, y.real, z.real]))
    latex += R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

In [None]:
%matplotlib widget

In [None]:
# Sliders
def set_slider(slider, value):
    """Set ``slider`` to ``value`` with a range that comfortably contains it.

    The range always covers at least [0, 5] and is widened to ``2 * value`` on
    whichever side that value falls. The bounds and step are assigned before
    the value, presumably so that the value is not clipped against the
    slider's previous range.
    """
    doubled = 2 * value
    lower_bound = min(0.0, doubled)
    upper_bound = max(5.0, doubled)
    slider.min, slider.max, slider.step = lower_bound, upper_bound, 0.01
    slider.value = value


# One slider per free parameter. Complex production couplings get two sliders,
# keyed `<symbol name>_real` and `<symbol name>_imag`.
sliders = {}
for symbol, value in free_parameters.items():
    is_production_coupling = symbol.name.startswith(R"\mathcal{H}") and (
        "production" in symbol.name
    )
    if is_production_coupling:
        value = complex(value)
        for part, part_value, description in [
            ("real", value.real, R"\(\mathrm{Re}\)"),
            ("imag", value.imag, R"\(\mathrm{Im}\)"),
        ]:
            slider = create_slider(symbol)
            sliders[f"{symbol.name}_{part}"] = slider
            set_slider(slider, part_value)
            slider.description = description
    else:
        slider = create_slider(symbol)
        sliders[symbol.name] = slider
        set_slider(slider, value)

σ3_max = (m0_val - m3_val) ** 2
σ3_min = (m1_val + m2_val) ** 2

# Kinematically allowed (σ_min, σ_max) per resonance family; the first family
# whose key occurs in a mass-slider name determines that slider's range.
mass_ranges = [
    ("K", (σ1_min, σ1_max)),
    (R"\Lambda", (σ2_min, σ2_max)),
    (R"\Delta", (σ3_min, σ3_max)),
]
for name, slider in sliders.items():
    if name.startswith(R"\Gamma_"):
        slider.min = 0
        slider.max = max(0.5, 2 * slider.value)
        continue
    if not name.startswith("m_"):
        continue
    for key, (σ_min, σ_max) in mass_ranges:
        if key in name:
            slider.min = np.sqrt(σ_min)
            slider.max = np.sqrt(σ_max)
            break

latex = {symbol.name: sp.latex(symbol) for symbol in free_parameters}
mass_sliders = [s for n, s in sliders.items() if n.startswith("m_")]
width_sliders = [s for n, s in sliders.items() if n.startswith(R"\Gamma_")]

# Group the coupling sliders per resonance as (Re sliders, Im sliders, labels).
# Each coupling has a `_real` and an `_imag` slider, so the `_imag` slider is
# looked up by name; this keeps the two parts of a coupling in the same row.
coupling_sliders = {}
for res_list in resonances.values():
    for res in res_list:
        real_names = [n for n in sliders if n.endswith("_real") and res.latex in n]
        symbol_names = [n[: -len("_real")] for n in real_names]
        coupling_sliders[res.name] = (
            [sliders[f"{name}_real"] for name in symbol_names],
            [sliders[f"{name}_imag"] for name in symbol_names],
            [HTMLMath(f"${latex[name]}$") for name in symbol_names],
        )


def _titled_tab(children, titles) -> Tab:
    """Create a `Tab` widget whose pages are labelled with ``titles``.

    ipywidgets 8 replaced the private ``_titles`` dict of ipywidgets 7 with a
    public ``titles`` tuple, so passing ``_titles`` to the constructor leaves
    the tabs unlabelled there. Use whichever public API is available.
    """
    tab = Tab(children=children)
    labels = [str(title) for title in titles]
    if hasattr(tab, "titles"):  # ipywidgets >= 8
        tab.titles = tuple(labels)
    else:  # ipywidgets 7
        for index, label in enumerate(labels):
            tab.set_title(index, label)
    return tab


ui = _titled_tab(
    children=[
        _titled_tab(
            children=[
                VBox([HBox(s) for s in zip(*pair)])
                for pair in coupling_sliders.values()
            ],
            titles=coupling_sliders,
        ),
        VBox([HBox([r, i]) for r, i in zip(mass_sliders, width_sliders)]),
    ],
    titles=["Couplings", "Masses and widths"],
)


# Visualization
def to_complex_kwargs(**kwargs):
    """Merge ``<name>_real`` / ``<name>_imag`` slider values into complex numbers.

    Each pair ``{f"{name}_real": re, f"{name}_imag": im}`` becomes a single
    entry ``{name: complex(re, im)}``, inserted where the ``_real`` key was.
    All other keyword arguments are passed through unchanged. An ``_imag``
    key without a matching ``_real`` key is dropped.

    Raises:
        KeyError: if a ``_real`` key has no matching ``_imag`` key.
    """
    complex_valued_kwargs = {}
    for key, value in kwargs.items():
        # Match the full suffix including the underscore, so that names that
        # merely end in "real"/"imag" are not truncated or silently dropped.
        if key.endswith("_real"):
            symbol_name = key[: -len("_real")]
            imag = kwargs[f"{symbol_name}_imag"]
            complex_valued_kwargs[symbol_name] = complex(value, imag)
        elif not key.endswith("_imag"):
            complex_valued_kwargs[key] = value
    return complex_valued_kwargs


def visualize_visualization() -> None:
    """Display interactive colour maps of the polarization components.

    Creates a 2x3 grid of axes. Rows show ``polarization_functions[0]`` and
    ``polarization_functions[1]`` (the latter titled as the K** contribution);
    columns show the x, y and z components. The global ``sliders`` drive a
    callback that re-evaluates every function on ``kinematic_arrays`` with the
    slider values and updates the colour meshes in place, with the colour scale
    fixed to [-1, +1]. The slider ``ui`` and the output area are then displayed.
    """
    n_rows, n_cols = 2, 3
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=0.8 * np.array([13, 9]),
        sharex=True,
        sharey=True,
        gridspec_kw={"width_ratios": [1, 1, 1.24]},
        tight_layout=True,
    )
    # Hide the widget-canvas decorations around the figure
    for attribute in ("toolbar_visible", "header_visible", "footer_visible"):
        setattr(fig.canvas, attribute, False)

    for (row, col), ax in np.ndenumerate(axes):
        title = Rf"\alpha_{'xyz'[col]}"
        if row == 1:
            title += R"\left(K^{**}\right)"
        ax.set_title(f"${title}$")
        if row == n_rows - 1:
            ax.set_xlabel(s1_label)
        if col == 0:
            ax.set_ylabel(s2_label)

    meshes = np.full([n_rows, n_cols], None)

    def plot3(**kwargs):
        parameters = to_complex_kwargs(**kwargs)
        for (row, col), ax in np.ndenumerate(axes):
            func = polarization_functions[row][col]
            func.update_parameters(parameters)
            z_values = np.real(func(kinematic_arrays))
            mesh = meshes[row, col]
            if mesh is None:
                # First call: create the mesh (plus a colour bar in the last column)
                mesh = ax.pcolormesh(X, Y, z_values, cmap=cm.coolwarm)
                meshes[row, col] = mesh
                if col == n_cols - 1:
                    fig.colorbar(mesh, ax=ax)
            else:
                mesh.set_array(z_values)
            mesh.set_clim(vmin=-1, vmax=+1)
        fig.canvas.draw()

    output = interactive_output(plot3, controls=sliders)
    display(ui, output)


visualize_visualization()

In [None]:
# On static builds (EXECUTE_NB or READTHEDOCS set), save the current pyplot
# figure to a PNG file and display it as a plain image, presumably because the
# interactive widget output is not rendered there.
if STATIC_WEB_PAGE:
    filename = "017-polarization-sensitivity.png"
    plt.savefig(filename)
    display(Image(filename=filename))