In [179]:
from tf_pwa.config_loader import ConfigLoader
from tf_pwa.amp.core import Particle
import numpy as np
import tensorflow as tf
import json
import warnings

In [294]:
config = ConfigLoader("config_a.yml")
config.set_params("final_params.json")

p4 = config.generate_phsp_p(1)
data = config.data.cal_angle(p)

amp_model = config.get_amplitude()

In [288]:
print(amp_model.decay_group.chains[1][1].resonances())
for i in range(len(amp_model.decay_group.chains[0])):
    print(type(amp_model.decay_group.chains[0][i]))

AttributeError: 'HelicityDecay' object has no attribute 'resonances'

In [293]:
dg = amp_model.decay_group
chains = getattr(dg, "decay_chains", None) or getattr(dg, "chains")

def iter_res(decay):
    r = getattr(decay, "res", None) or getattr(decay, "resonances", None)
    if r is None and hasattr(decay, "get_resonances"):
        r = decay.get_resonances()
    return r or []

for chain in chains:
    for decay in chain:
        for res in iter_res(decay):
            # einige Versionen erwarten (data, data_p); nimm beidenfalls data
            amp = res.get_amp(data) if res.get_amp.__code__.co_argcount == 2 else res.get_amp(data, data)
            print(res.name, amp.numpy() if hasattr(amp, "numpy") else amp)

In [295]:
amp_model = config.get_amplitude()
assert hasattr(amp_model, "decay_group"), "amp_model hat kein decay_group"

dg = amp_model.decay_group

# chains vs decay_chains
chains = getattr(dg, "decay_chains", None) or getattr(dg, "chains", None)
print("chains type/len:", type(chains), (len(chains) if chains is not None else None))
assert chains is not None, "Keine chains/decay_chains gefunden"

def iter_steps(chain):
    # Manche Versionen liefern die Schritte direkt iterierbar,
    # andere haben Attribute wie 'decays' oder 'chain'
    if isinstance(chain, (list, tuple)):
        return chain
    for attr in ("decays", "chain", "steps"):
        if hasattr(chain, attr):
            return getattr(chain, attr)
    return []

def iter_resonances(step):
    # Versuche typische Namen/Getter; Liste oder Dict möglich
    for attr in ("resonances", "res", "res_list"):
        if hasattr(step, attr):
            obj = getattr(step, attr)
            if isinstance(obj, dict):
                return list(obj.values())
            return list(obj)
    if hasattr(step, "get_resonances"):
        try:
            return list(step.get_resonances())
        except TypeError:
            # falls get_resonances(event_data) erwartet
            return list(step.get_resonances())
    return []

any_printed = False

for ci, chain in enumerate(chains):
    steps = iter_steps(chain)
    print(f"chain {ci}: steps={len(steps)}")
    for si, step in enumerate(steps):
        print("  step", si, "type:", type(step).__name__, "outs:", getattr(step, "outs", None))
        print("   res-attr candidates:", [a for a in dir(step) if "res" in a.lower()])
        res_list = iter_resonances(step)
        if not res_list:
            print("   (keine Resonanzen in diesem Schritt)")
            continue
        for r in res_list:
            # Manche Versionen brauchen (data, data) statt nur (data)
            try:
                amp = r.get_amp(data)
            except TypeError:
                amp = r.get_amp(data, data)
            name = getattr(r, "name", repr(r))
            # Bei Single-Event ggf. nur Skalar:
            try:
                val = amp.numpy()
            except AttributeError:
                val = amp
            print(f"   {name}: {val}")
            any_printed = True

print("printed_any:", any_printed)

chains type/len: <class 'list'> 13
chain 0: steps=3
  step 0 type: HelicityDecay outs: (X(3872), K)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schritt)
  step 1 type: HelicityDecay outs: (Dst, D)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schritt)
  step 2 type: HelicityDecay outs: (D0, pi)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schritt)
chain 1: steps=3
  step 0 type: HelicityDecay outs: (X(3915)(0-), K)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schritt)
  step 1 type: HelicityDecay outs: (Dst, D)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schritt)
  step 2 type: HelicityDecay outs: (D0, pi)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schritt)
chain 2: steps=3
  step 0 type: HelicityDecay outs: (chi(c2)(3930), K)
   res-attr candidates: ['below_threshold']
   (keine Resonanzen in diesem Schri

In [296]:
amp_model = config.get_amplitude()
dg = amp_model.decay_group

chains = getattr(dg, "decay_chains", None) or getattr(dg, "chains", None)
n_chains = len(chains)
print("n_chains:", n_chains)

per_chain_amps = []

# 1) Bevorzugt: faktorisiert je Kette
if hasattr(dg, "get_factor_angle_amp"):
    try:
        pcs = dg.get_factor_angle_amp(data, data)  # viele Versionen erwarten (data, data)
    except TypeError:
        pcs = dg.get_factor_angle_amp(data)
    # pcs ist i.d.R. eine Liste Länge n_chains; jedes Element komplex (Batch-Länge)
    for i, a in enumerate(pcs):
        a_np = a.numpy() if hasattr(a, "numpy") else np.asarray(a)
        per_chain_amps.append(a_np)
        print(f"chain {i} amp (len={len(a_np)}):", a_np[:3])
# 2) Alternativ: direkte per-Chain-Funktion
elif hasattr(dg, "get_chain_amp"):
    for i in range(n_chains):
        try:
            a = dg.get_chain_amp(i, data, data)
        except TypeError:
            a = dg.get_chain_amp(i, data)
        a_np = a.numpy() if hasattr(a, "numpy") else np.asarray(a)
        per_chain_amps.append(a_np)
        print(f"chain {i} amp (len={len(a_np)}):", a_np[:3])
# 3) Fallback: manche Versionen bieten eine Liste aller Ketten-Amps
elif hasattr(dg, "get_amp_list"):
    try:
        lst = dg.get_amp_list(data, data)
    except TypeError:
        lst = dg.get_amp_list(data)
    for i, a in enumerate(lst):
        a_np = a.numpy() if hasattr(a, "numpy") else np.asarray(a)
        per_chain_amps.append(a_np)
        print(f"chain {i} amp (len={len(a_np)}):", a_np[:3])
else:
    print("Keine per-Chain-API gefunden. Bitte gib mir die Methodenliste aus:")
    print([m for m in dir(dg) if ("amp" in m.lower() or "chain" in m.lower() or "factor" in m.lower())])

n_chains: 13
chain 0 amp (len=1): [[[[[[[[[0.32168846+0.j]]]]]]





   [[[[[[0.01650126+0.j]]]]]]]]]
chain 1 amp (len=1): [[[[[[[[[0.4144638+0.j]]]]]]]]]
chain 2 amp (len=1): [[[[[[[[[0.+0.37300291j]]]]]]]]]
chain 3 amp (len=1): [[[[[[[[[0.32168846+0.j]]]]]]





   [[[[[[0.01650126+0.j]]]]]]]]]
chain 4 amp (len=1): [[[[[[[[[0.32168846+0.j]]]]]]





   [[[[[[0.01650126+0.j]]]]]]]]]
chain 5 amp (len=1): [[[[[[[[[0.-0.51816686j]]]]]]]]]
chain 6 amp (len=1): [[[[[[[[[0.32168846+0.j]]]]]]





   [[[[[[0.01650126+0.j]]]]]]]]]
chain 7 amp (len=1): [[[[[[[[[0.4144638+0.j]]]]]]]]]
chain 8 amp (len=1): [[[[[[[[[0.32168846+0.j]]]]]]]]]
chain 9 amp (len=1): [[[[[[[[[0.4144638+0.j]]]]]]]]]
chain 10 amp (len=1): [[[[[[[[[0.-0.51816686j]]]]]]]]]
chain 11 amp (len=1): [[[[[[[[[0.57837036+0.j]]]]]]]]]
chain 12 amp (len=1): [[[[[[[[[-0.29818786+0.j        ]]]]]]]






  [[[[[[[ 0.        +0.37968393j]]]]]]]






  [[[[[[[ 0.31725527+0.j        ]]]]]]]]]


In [None]:
res_names = []
res_amps = {}

for chain in amp_model.chains:
    for decay in chain:
        for res in decay.resonances:
            name = res.name
            res_names.append(name)
            res_amps[name] = res.get_amp(data).numpy()

df = pd.DataFrame(res_amps)
print(df.head())

AttributeError: 'AmplitudeModel' object has no attribute 'chains'

In [188]:
"""
Functions
"""

def print_structure(d, indent=0):
    prefix = "  " * indent
    if isinstance(d, dict):
        for k, v in d.items():
            print(f"{prefix}{k}:")
            print_structure(v, indent + 1)
    elif isinstance(d, tf.Tensor):
        print(f"{prefix}tf.Tensor(shape={d.shape}, dtype={d.dtype})")
    elif isinstance(d, np.ndarray):
        print(f"{prefix}np.ndarray(shape={d.shape}, dtype={d.dtype})")
    else:
        print(f"{prefix}{type(d)}")


def extract_single_event(data):
    if isinstance(data, dict):
        return {k: extract_single_event(v) for k, v in data.items()}
    elif isinstance(data, (tf.Tensor, np.ndarray)):
        return data[:1]
    else:
        return data
    
    
def complex_amplitude(single_event):
    return config.get_amplitude().decay_group.get_amp(single_event).numpy().flatten()[0]


def list_resonances(config):

    warnings.filterwarnings("ignore")

    decay_chains = config.get_decay()

    used_res = set()

    for chain in decay_chains:
        for decay in chain:
            used_res.add(decay.core.name)

    return sorted(used_res)

In [205]:
"Filtering out one event"
phsp_single = extract_single_event(phsp)

print_structure(phsp_single)

particle:
  pi:
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  (D0, K, pi):
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  (D0, pi):
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  (D, D0, pi):
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  (D, K):
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  D0:
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  K:
    m:
      tf.Tensor(shape=(1,), dtype=<dtype: 'float64'>)
    p:
      tf.Tensor(shape=(1, 4), dtype=<dtype: 'float64'>)
  Bp:
    m:
      tf.Tensor(shape=(1,), dtype=<dtype

In [206]:
"Calculating the complex amplitude"
A_single = complex_amplitude(phsp_single)

print(A_single)

(0.04850502291779847-0.09498262548371192j)


In [207]:
"Display all resonances involved"
Resonances = list_resonances(config)

for i in Resonances:
    print(i)

Bp
Dst
NR(0-)SPm
NR(0-)SPp
NR(1+)PSp
NR(1-)PPm
Psi(4040)
X(3872)
X(3915)(0-)
X(3940)(1+)
X(3993)
X(4300)
X0(2900)
X1(2900)
chi(c2)(3930)


In [200]:
print(config.decay_group())

AttributeError: 'ConfigLoader' object has no attribute 'decay_group'