# EMPCA training (frequency domain)

Step 1: imports + load traces + load PSD.

## Imports

In [4]:
import numpy as np
import h5py

# Import from the provided implementation
from empca_TCY import EMPCA, ti_rfft  # paper-style transform + EM-PCA

# ----------------------------
# 1) Baseline correction
# ----------------------------
def baseline_correct_per_trace(X_time, pretrigger=4000, method="mean"):
    """
    Subtract a per-trace baseline estimated from the leading samples.

    Parameters
    ----------
    X_time : array-like, shape (n_events, n_time)
        Time-domain traces; converted to float64.
    pretrigger : int
        Number of leading samples per trace used for the baseline estimate
        (assumed pulse-free by the caller). Must lie in [1, n_time].
    method : {"mean", "median"}
        Statistic applied to the leading samples of each trace.

    Returns
    -------
    X0 : ndarray, shape (n_events, n_time)
        Traces with their own baseline removed.
    baseline : ndarray, shape (n_events,)
        The baseline value subtracted from each trace.

    Raises
    ------
    ValueError
        If the input is not 2D, if ``pretrigger`` is out of range, or if
        ``method`` is not one of the supported names.
    """
    traces = np.asarray(X_time, dtype=np.float64)
    if traces.ndim != 2:
        raise ValueError(f"X_time must be 2D; got {traces.shape}")

    n_time = traces.shape[1]
    if not (1 <= pretrigger <= n_time):
        raise ValueError("pretrigger must be within [1, n_time]")

    if method not in ("mean", "median"):
        raise ValueError("method must be 'mean' or 'median'")

    # Pick the per-row statistic, then evaluate it on the leading window only.
    reducer = np.mean if method == "mean" else np.median
    baseline = reducer(traces[:, :pretrigger], axis=1)

    # Broadcast the (n_events,) baseline across the time axis.
    return traces - baseline[:, np.newaxis], baseline


# ----------------------------
# 2) Frequency-domain shift-invariant transform
# ----------------------------
def to_shift_invariant_spectrum(X_time):
    """
    Run ``ti_rfft`` (imported from ``empca_TCY``) on the traces and check
    that the result is two-dimensional.

    The caller expects a complex array of shape (n_events, n_freq); only the
    dimensionality is verified here. The shift-invariance of the
    representation is a property of ``ti_rfft`` as described by the original
    author, not something this wrapper checks.

    Raises
    ------
    RuntimeError
        If ``ti_rfft`` returns something that is not 2D.
    """
    spectrum = ti_rfft(X_time)
    if spectrum.ndim == 2:
        return spectrum
    raise RuntimeError("Unexpected output shape from ti_rfft")


# ----------------------------
# 3) Build weights matrix from PSD
# ----------------------------
def make_inverse_psd_weights(noise_psd, eps=1e-18):
    """
    Build the diagonal inverse-PSD weight matrix.

    Parameters
    ----------
    noise_psd : array-like, shape (n_freq,)
        Noise PSD, one value per frequency bin of the data it will weight.
        Entries must be non-negative and not NaN. ``inf`` is accepted and
        yields a zero weight for that bin.
    eps : float
        Small constant added to every bin before inversion so that a zero
        PSD bin does not cause a division by zero.

    Returns
    -------
    W : ndarray, shape (n_freq, n_freq)
        Diagonal matrix whose diagonal entries are ``1 / (noise_psd + eps)``.

    Raises
    ------
    ValueError
        If ``noise_psd`` is not 1D, or contains NaN or negative values.
        Such values would otherwise silently produce NaN or negative weights.

    Notes
    -----
    The result is a dense float64 matrix, so memory grows as n_freq**2
    (about 2.1 GB for 16385 bins).
    NOTE(review): if the consumer only needs the diagonal, passing the 1D
    inverse PSD would be far cheaper - confirm what EMPCA.fit accepts.
    """
    noise_psd = np.asarray(noise_psd, dtype=np.float64)
    if noise_psd.ndim != 1:
        raise ValueError(f"noise_psd must be 1D; got {noise_psd.shape}")
    # NaN compares False against 0, so it needs its own check.
    if np.isnan(noise_psd).any():
        raise ValueError("noise_psd contains NaN values")
    if (noise_psd < 0).any():
        raise ValueError("noise_psd contains negative values")

    inv = 1.0 / (noise_psd + eps)
    return np.diag(inv)


# ----------------------------
# 4) Train EM-PCA
# ----------------------------
def train_empca_paper_style(
    X_time,
    noise_psd,
    n_comp=2,            # original author's note: the paper used 2 components
    pretrigger=4000,
    baseline_method="mean",
    n_iter=50,
    mode="fast",         # original author's note: per-component approximation
    window=15, polyord=3, deriv=0,
):
    """
    Baseline-correct the traces, transform them with ``ti_rfft``, and fit an
    inverse-PSD-weighted EMPCA model.

    Parameters
    ----------
    X_time : array-like, shape (n_events, n_time)
        Time-domain traces.
    noise_psd : sequence, length n_freq
        Noise PSD; its length must equal the number of frequency bins in the
        transformed data.
    n_comp : int
        Number of components passed to ``EMPCA``.
    pretrigger, baseline_method
        Forwarded to ``baseline_correct_per_trace``.
    n_iter, mode, window, polyord, deriv
        Forwarded unchanged to ``EMPCA.fit``. Their exact meaning is defined
        by ``empca_TCY`` (presumably iteration count, fit variant, and
        smoothing-filter settings - confirm against that module).

    Returns
    -------
    pca : EMPCA
        The fitted model object.
    chi2s
        Whatever ``EMPCA.fit`` returns.
    baseline : ndarray, shape (n_events,)
        Per-trace baseline that was subtracted.
    X_tilde : ndarray, shape (n_events, n_freq)
        Transformed traces used for the fit.
    W : ndarray, shape (n_freq, n_freq)
        Diagonal inverse-PSD weight matrix used for the fit.

    Raises
    ------
    ValueError
        If the PSD length does not match the number of frequency bins, or if
        a helper rejects its input.
    """
    # 4.1 baseline subtraction
    X0, baseline = baseline_correct_per_trace(
        X_time, pretrigger=pretrigger, method=baseline_method
    )

    # 4.2 frequency representation, (n_events, n_freq)
    X_tilde = to_shift_invariant_spectrum(X0)

    # 4.3 ensure PSD matches frequency dimension
    n_freq = X_tilde.shape[1]
    if len(noise_psd) != n_freq:
        # Use the converted array's shape: X_time itself may be a plain nested
        # list without a .shape attribute, which would mask this error.
        raise ValueError(
            f"PSD length mismatch: PSD has {len(noise_psd)} bins but data has {n_freq}. "
            f"For n_time={X0.shape[1]}, expected n_freq = n_time//2 + 1."
        )

    # 4.4 weights W = diag(1/PSD)
    W = make_inverse_psd_weights(noise_psd)

    # 4.5 EM-PCA fit
    pca = EMPCA(n_comp=n_comp)
    chi2s = pca.fit(
        X_tilde, W,
        n_iter=n_iter,
        mode=mode,
        window=window, polyord=polyord, deriv=deriv,
        verbose=False
    )

    return pca, chi2s, baseline, X_tilde, W


# ----------------------------
# 5) Pulse-height estimator: coefficient norm (paper)
# ----------------------------
def pca_amplitude_estimator(coeff):
    """
    Per-event amplitude estimate: the Euclidean norm of each coefficient row.

    Parameters
    ----------
    coeff : array-like, shape (n_events, n_comp)
        Real or complex component coefficients.

    Returns
    -------
    ndarray, shape (n_events,)
        ``sqrt(sum_k |b_k|^2)`` for every event. For n_comp=2 this is
        ``sqrt(|b1|^2 + |b2|^2)``; the same formula is used for any n_comp.
    """
    # The modulus handles complex coefficients; for real input it is |x|.
    magnitudes = np.abs(np.asarray(coeff))
    squared_norm = np.sum(magnitudes ** 2, axis=1)
    return np.sqrt(squared_norm)




## Load traces

In [7]:
# Load matched traces and rqs from HDF5
# Each file is opened read-only and its dataset is read fully into memory.
# "Matched" presumably means one rqs entry per trace - this is not checked
# here; compare the printed shapes below.
traces_h5 = "k_alpha_traces.h5"
rqs_h5 = "k_alpha_rqs.h5"

with h5py.File(traces_h5, "r") as f:
    # Whole "traces" dataset, cast to float64 for the downstream processing.
    X = f["traces"][:].astype(np.float64)

with h5py.File(rqs_h5, "r") as f:
    # Whole "rqs" dataset, kept in its stored dtype.
    rqs = f["rqs"][:]

# Quick sanity check of what was loaded.
print("X shape:", X.shape)
print("rqs shape:", rqs.shape)


X shape: (4358, 32768)
rqs shape: (4358,)


## Load noise PSD

In [6]:
# Noise PSD stored as a .npy file at an absolute, site-specific path.
psd_path = "/ceph/dwong/delight/noise_psd_xray.npy"
psd_arr = np.load(psd_path)

# A 2D array with exactly two rows is treated as a stacked pair and the
# second row is used as the PSD (presumably row 0 is the frequency axis -
# TODO confirm). Any other layout is used as-is.
if psd_arr.ndim == 2 and psd_arr.shape[0] == 2:
    psd = psd_arr[1]
else:
    psd = psd_arr
psd = psd.astype(np.float64)

print("PSD shape:", psd.shape)


PSD shape: (16385,)


In [8]:
# ----------------------------
# Example usage (X, psd loaded above)
# ----------------------------
# Keyword arguments for train_empca_paper_style, kept in one dict so the
# exact settings can be stored next to the fitted model below.
training_params = {
    "n_comp": 2,
    "pretrigger": 4000,
    "baseline_method": "mean",
    "n_iter": 30,
    "mode": "fast",
    "window": 15,
    "polyord": 3,
    "deriv": 0,
}

# NOTE(review): in the recorded run this call was interrupted after 1 of 30
# iterations (about 15 minutes per iteration), so everything below it did
# not execute in that session.
pca, chi2s, baselines, X_tilde, W = train_empca_paper_style(
    X, psd,
    **training_params
)

# Per-event coefficients from the fitted model, reduced to one amplitude
# value per event.
coeff = pca.coeff
amp_raw = pca_amplitude_estimator(coeff)


# Save model artifact in a common ML format (pickle)
# The artifact bundles the fitted object, the training settings, and the
# path of the PSD file (the path only, not the PSD values).
# NOTE: unpickling needs the EMPCA class to be importable and should only be
# done with files from a trusted source.
import pickle
model_path = "empca_frequency_model.pkl"
model_artifact = {
    "pca": pca,
    "training_params": training_params,
    "psd_path": psd_path,
}
with open(model_path, "wb") as f:
    pickle.dump(model_artifact, f)

print("saved model to", model_path)


  3%|▎         | 1/30 [15:04<7:17:22, 904.91s/it]


KeyboardInterrupt: 