# Joint Training of Three Normalizing Flows with Constraint

## Overview

This notebook implements **joint training** of all three normalizing flows:
- `flow_flavor`: Models the flavor-tagged $D \to K_S^0\pi^+\pi^-$ Dalitz-plot density (so $\sqrt{p_{\text{flavor}}} = |A|$, as used in abJ below)
- `flow_even`: Models the CP-even combination
- `flow_odd`: Models the CP-odd combination

Unlike the constrained training in `train_flow_odd_constrained.ipynb` where only `flow_odd` is trained,
here all three flows are trained simultaneously with:
1. Individual NLL losses for each flow
2. A shared constraint penalty enforcing $|C| \leq |A||\bar{A}|$

## Loss Function

$$\mathcal{L} = \underbrace{-\langle \log p_{\text{flavor}}(x_{\text{flavor}}) \rangle}_{\text{NLL flavor}} + \underbrace{-\langle \log p_{\text{even}}(x_{\text{even}}) \rangle}_{\text{NLL even}} + \underbrace{-\langle \log p_{\text{odd}}(x_{\text{odd}}) \rangle}_{\text{NLL odd}} + \lambda \underbrace{\langle \max(|C| - \text{abJ}, 0)^2 \rangle}_{\text{constraint}}$$

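where (note: the training function in Section 6 currently applies the hinge penalty linearly, `violation ** 1`, rather than squared as written above):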
- $C = \frac{\Gamma_+ \cdot p_{\text{even}} - \Gamma_- \cdot p_{\text{odd}}}{\Gamma_+ + \Gamma_-}$
- $\text{abJ} = |A(s_{12}, s_{13})| \cdot |A(s_{13}, s_{12})| = \sqrt{p_{\text{flavor}}(m', \theta')} \cdot \sqrt{p_{\text{flavor}}(m'_{\text{swap}}, \theta'_{\text{swap}})}$

The swapped SDP coordinates $(m'_{\text{swap}}, \theta'_{\text{swap}})$ are computed via the standard
`sdp_to_dp → swap → dp_to_sdp` transformation.

---
## 1. Imports and Setup

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json
import os
from pathlib import Path
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from nflows.flows import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms import CompositeTransform, RandomPermutation
from nflows.transforms.coupling import PiecewiseRationalQuadraticCouplingTransform
from nflows.transforms.base import Transform
from nflows.transforms import Sigmoid, InverseTransform

# --- Plotting defaults: seaborn theme, then matplotlib rc overrides ---
sns.set()
sns.set_style("ticks")
sns.set_context("paper", font_scale=1.5)
plt.rcParams.update({'text.usetex': False, 'font.size': 12})

# --- Compute device: prefer CUDA when it is available ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cpu


In [3]:
from Amplitude import DKpp, BKpp, DalitzSample, AmpSample, SquareDalitzPlot2

# --- Particle masses (D, K_S, pi); values look like PDG masses in GeV — confirm ---
mD = 1.86483
mKs = 0.497611
mpi = 0.13957018
# SDP helper built from the hard-coded masses above.
# NOTE(review): the code visible in this notebook uses `sdp_obj` (below) everywhere instead.
SDP = SquareDalitzPlot2(mD, mKs, mpi, mpi)

# Amplitude model, plus an SDP helper built from the model's own mass definitions
dkpp = DKpp()
sdp_obj = SquareDalitzPlot2(dkpp.M(), dkpp.m1(), dkpp.m2(), dkpp.m3())

---
## 2. Coordinate Transformation Functions

Functions to convert between Square Dalitz Plot (SDP) and standard Dalitz Plot (DP) coordinates,
and to compute swapped coordinates for the constraint.

In [4]:
def sdp_to_dp(points_sdp, sdp_obj, idx=(1,2,3)):
    """
    Map Square Dalitz Plot points (m', θ') onto Dalitz Plot points (s_ij, s_ik).

    Parameters
    ----------
    points_sdp : array-like, shape (N, 2)
        Rows of SDP coordinates (m', θ').
    sdp_obj : SquareDalitzPlot2
        Object providing ``M_from_MpT(mp, th, i, j, k)``.
    idx : tuple
        Particle indices (i, j, k) forwarded to ``M_from_MpT``.

    Returns
    -------
    ndarray, shape (N, 2), dtype float
        Rows of DP coordinates (s_ij, s_ik).
    """
    i, j, k = idx
    result = np.empty_like(points_sdp, dtype=float)
    # One scalar call per point: M_from_MpT is not vectorised.
    for row, point in enumerate(points_sdp):
        mprime, theta = point
        s_ij, s_ik = sdp_obj.M_from_MpT(mprime, theta, i, j, k)
        result[row] = (s_ij, s_ik)
    return result


def dp_to_sdp(points_dp, sdp_obj, idx=(1,2,3)):
    """
    Map Dalitz Plot points (s_ij, s_ik) onto Square Dalitz Plot points (m', θ').

    Parameters
    ----------
    points_dp : ndarray, shape (N, 2)
        Rows of DP coordinates (s_ij, s_ik).
    sdp_obj : SquareDalitzPlot2
        Object providing ``MpfromM`` and ``TfromM`` (scalar functions).
    idx : tuple
        Particle indices (i, j, k) forwarded to both functions.

    Returns
    -------
    ndarray, shape (N, 2)
        Rows of SDP coordinates (m', θ').
    """
    i, j, k = idx
    s_ij, s_ik = points_dp[:, 0], points_dp[:, 1]

    def _mprime(a, b):
        return sdp_obj.MpfromM(a, b, i, j, k)

    def _theta(a, b):
        return sdp_obj.TfromM(a, b, i, j, k)

    # np.vectorize just loops in Python; otypes avoids a trial call for dtype.
    mprime = np.vectorize(_mprime, otypes=[float])(s_ij, s_ik)
    theta = np.vectorize(_theta, otypes=[float])(s_ij, s_ik)
    return np.column_stack((mprime, theta))


def swap_to_other_pair_sdp(s12, s13, sdp_obj, pair_swap=(1,3,2)):
    """
    Convert Dalitz points to SDP coordinates with s12 and s13 exchanged.

    For each point this evaluates ``MpfromM(s13, s12, *pair_swap)`` and
    ``TfromM(s13, s12, *pair_swap)``. It corresponds to evaluating
    A(s13, s12) instead of A(s12, s13).

    Parameters
    ----------
    s12, s13 : array-like or scalar
        Dalitz plot coordinates. They are flattened to 1-D float64 arrays.
    sdp_obj : SquareDalitzPlot2
        Object providing scalar ``MpfromM`` and ``TfromM``.
    pair_swap : tuple
        Particle indices (i, j, k) used for the swapped evaluation.

    Returns
    -------
    ndarray, shape (N, 2), dtype float64
        Swapped SDP coordinates (m'_swap, θ'_swap).

    Raises
    ------
    ValueError
        If ``s12`` and ``s13`` do not hold the same number of points.
    """
    i2, j2, k2 = pair_swap
    # Work in float64 and allocate float outputs explicitly. np.empty_like on
    # an integer or float32 input would otherwise truncate the SDP values.
    s12 = np.asarray(s12, dtype=float).ravel()
    s13 = np.asarray(s13, dtype=float).ravel()
    if s12.size != s13.size:
        raise ValueError(
            f"s12 and s13 must have the same number of points "
            f"(got {s12.size} and {s13.size})"
        )

    mp_swap = np.empty(s12.size, dtype=float)
    th_swap = np.empty(s12.size, dtype=float)
    # The swap: s13 takes the role of the first Dalitz variable.
    for n, (first, second) in enumerate(zip(s13, s12)):
        mp_swap[n] = sdp_obj.MpfromM(first, second, i2, j2, k2)
        th_swap[n] = sdp_obj.TfromM(first, second, i2, j2, k2)
    return np.column_stack([mp_swap, th_swap])


def compute_swapped_sdp_coords(points_sdp, sdp_obj, idx=(1,2,3), pair_swap=(1,3,2)):
    """
    Compute s12 <-> s13 swapped SDP coordinates for a batch of SDP points.

    The pipeline is SDP -> DP (with ``idx``), then swap s12/s13 and go back
    to SDP (with ``pair_swap``).

    Parameters
    ----------
    points_sdp : ndarray, shape (N, 2)
        Original SDP coordinates (m', θ').
    sdp_obj : SquareDalitzPlot2
        SDP helper object.
    idx : tuple
        Particle indices for the SDP -> DP step.
    pair_swap : tuple
        Particle indices for the swapped DP -> SDP step.

    Returns
    -------
    ndarray, shape (N, 2)
        Swapped SDP coordinates.
    """
    dalitz = sdp_to_dp(points_sdp, sdp_obj, idx=idx)
    return swap_to_other_pair_sdp(
        dalitz[:, 0], dalitz[:, 1], sdp_obj, pair_swap=pair_swap
    )

---
## 3. Model Architecture

Neural spline flow architecture (same as in other notebooks).

In [5]:
class MLP(nn.Module):
    """
    Conditioner network used inside the coupling layers.

    The network is a stack of ``layers`` Linear+SiLU blocks followed by a
    linear head. The head's output is multiplied by ``output_scale``. The head
    is zero-initialised, so the network output is exactly zero at construction.
    ``context`` is accepted for interface compatibility and is ignored.
    """
    def __init__(self, in_features, out_features,
                 hidden=64, layers=2, output_scale=0.30):
        super().__init__()
        stack = [nn.Linear(in_features, hidden), nn.SiLU()]
        for _ in range(layers - 1):
            stack.extend((nn.Linear(hidden, hidden), nn.SiLU()))
        self.backbone = nn.Sequential(*stack)
        self.head = nn.Linear(hidden, out_features)

        # Zero-init the head so the initial output is identically zero.
        for tensor in (self.head.weight, self.head.bias):
            nn.init.zeros_(tensor)

        self.output_scale = output_scale

    def forward(self, x, context=None):
        features = self.backbone(x)
        return self.output_scale * self.head(features)

In [6]:
def create_flow(on_unit_box=True, num_flows=8, hidden_features=64, num_bins=8, device=None):
    """
    Build a 2-D Neural Spline Flow (an nflows ``Flow``).

    Transform stack, in order:
      * if ``on_unit_box``: ``InverseTransform(Sigmoid())``. This is intended for
        inputs in the open unit box (presumably SDP coordinates in (0, 1)^2).
      * ``num_flows`` rational-quadratic spline coupling layers with alternating
        masks, each followed by a ``RandomPermutation``.
    Base distribution: 2-D standard normal.

    Parameters
    ----------
    on_unit_box : bool
        Prepend the inverse-sigmoid transform.
    num_flows : int
        Number of coupling layers.
    hidden_features : int
        Hidden width of each coupling layer's conditioner MLP.
    num_bins : int
        Number of spline bins per coupling layer.
    device : str or None
        Device for the Sigmoid's tensor constants. Defaults to "cuda" when
        available, else "cpu".

    Returns
    -------
    Flow
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    dim = 2
    layers = []

    if on_unit_box:
        sigmoid = Sigmoid()
        # Some Sigmoid constants may be plain tensor attributes rather than
        # registered buffers, so moving the flow would not move them.
        # Move them explicitly here.
        for attr in ('temperature', 'eps'):
            value = getattr(sigmoid, attr, None)
            if isinstance(value, torch.Tensor):
                setattr(sigmoid, attr, value.to(device))
        layers.append(InverseTransform(sigmoid))

    coupling_masks = (torch.tensor([1, 0], dtype=torch.bool),
                      torch.tensor([0, 1], dtype=torch.bool))

    def conditioner(in_features, out_features):
        return MLP(in_features, out_features, hidden=hidden_features, layers=2)

    for layer_idx in range(num_flows):
        layers.append(
            PiecewiseRationalQuadraticCouplingTransform(
                mask=coupling_masks[layer_idx % 2],
                transform_net_create_fn=conditioner,
                num_bins=num_bins,
                tails="linear",
                tail_bound=5.0,
                apply_unconditional_transform=False,
            )
        )
        layers.append(RandomPermutation(features=dim))

    return Flow(CompositeTransform(layers), StandardNormal(shape=[dim]))

---
## 4. Dataset Classes

We need datasets that provide:
1. Batches from all three data sources simultaneously
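2. Precomputed swapped coordinates for the CP-odd samples (note: the training loop ignores this element; the constraint penalty uses its own precomputed `constraint_points_swapped` on B-decay points)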

In [7]:
class DalitzDataset(Dataset):
    """
    Minimal map-style Dataset over an array of Dalitz-plot coordinates.

    A NumPy array is converted to a float32 tensor. Any other input (for
    example an existing tensor) is stored as is.
    """
    def __init__(self, data):
        self.data = torch.FloatTensor(data) if isinstance(data, np.ndarray) else data

    def __len__(self):
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        row = self.data[idx]
        return row

In [8]:
class JointDalitzDataset(Dataset):
    """
    Dataset yielding one sample from each of three sources per index, plus the
    swapped coordinates paired with the CP-odd sample.

    Each ``__getitem__`` returns ``(x_flavor, x_even, x_odd, x_odd_swapped)``.
    ``x_odd_swapped`` is the row of ``data_odd_swapped`` at the same
    (permuted) index as ``x_odd``.

    The epoch length is the size of the smallest source. Each source is read
    through its own random permutation. A source larger than ``epoch_len``
    therefore only exposes the first ``epoch_len`` entries of its current
    permutation. Call :meth:`reshuffle` (for example once per epoch) to draw
    new permutations and reach other rows of the larger sources.

    Parameters
    ----------
    data_flavor, data_even, data_odd : ndarray or Tensor
        Samples for each flow. NumPy arrays are converted to float32 tensors.
    data_odd_swapped : ndarray or Tensor
        Swapped coordinates, row-aligned with ``data_odd``.

    Raises
    ------
    ValueError
        If ``data_odd_swapped`` and ``data_odd`` have different row counts.
    """
    def __init__(self, data_flavor, data_even, data_odd, data_odd_swapped):
        self.data_flavor = self._to_tensor(data_flavor)
        self.data_even = self._to_tensor(data_even)
        self.data_odd = self._to_tensor(data_odd)
        self.data_odd_swapped = self._to_tensor(data_odd_swapped)

        # odd and odd_swapped share one permutation, so they must be row-aligned.
        if len(self.data_odd_swapped) != len(self.data_odd):
            raise ValueError(
                f"data_odd_swapped has {len(self.data_odd_swapped)} rows but "
                f"data_odd has {len(self.data_odd)}; they must be row-aligned."
            )

        # Use minimum size for epoch length
        self.epoch_len = min(len(self.data_flavor), len(self.data_even), len(self.data_odd))

        self._shuffle_indices()

    @staticmethod
    def _to_tensor(data):
        """Convert a NumPy array to a float32 tensor; return anything else unchanged."""
        return torch.FloatTensor(data) if isinstance(data, np.ndarray) else data

    def _shuffle_indices(self):
        """Draw an independent random permutation for each source."""
        self.idx_flavor = torch.randperm(len(self.data_flavor))
        self.idx_even = torch.randperm(len(self.data_even))
        # Shared by data_odd and data_odd_swapped (they are row-aligned).
        self.idx_odd = torch.randperm(len(self.data_odd))

    def reshuffle(self):
        """Redraw the per-source permutations, e.g. at the start of each epoch."""
        self._shuffle_indices()

    def __len__(self):
        return self.epoch_len

    def __getitem__(self, idx):
        # A DataLoader only draws idx < epoch_len <= len(source), so the modulo
        # is a no-op there. It only wraps explicitly out-of-range indices.
        idx_f = self.idx_flavor[idx % len(self.data_flavor)]
        idx_e = self.idx_even[idx % len(self.data_even)]
        idx_o = self.idx_odd[idx % len(self.data_odd)]

        return (self.data_flavor[idx_f],
                self.data_even[idx_e],
                self.data_odd[idx_o],
                self.data_odd_swapped[idx_o])  # Same index as odd

---
## 5. Compute Gamma Values

We need $\Gamma_+$ and $\Gamma_-$ for the constraint formula.

In [9]:
def _finite_pos(x, eps=1e-14):
    """
    Return ``x`` as an array whose entries are finite and at least ``eps``.

    NaN and +/-inf entries are first replaced by 0. Every entry is then
    clipped from below at ``eps``.
    """
    values = np.asarray(x)
    is_good = np.isfinite(values)
    cleaned = np.where(is_good, values, 0.0)
    return np.maximum(cleaned, eps)


def compute_Gamma_pm_from_amplitude(dkpp_model, sdp_obj, idx=(1,2,3), nx=1000, ny=1000):
    """
    Integrate |A_+|^2 and |A_-|^2 over the Dalitz plot on a regular SDP grid.

    A_+/- = A(s12, s13) +/- A(s13, s12), with A given by ``dkpp_model.full``.
    The integral is evaluated in SDP coordinates (m', θ') on an nx x ny grid
    over [0, 1]^2. The integration measure is
    ``1 / sdp_obj.jacobian(s12, s13, *idx)``, and the 2-D trapezoidal rule
    is used. Grid points where the integrand is not finite (e.g. a vanishing
    or singular Jacobian on the SDP boundary) contribute zero.

    Parameters
    ----------
    dkpp_model : object
        Amplitude model with a ``full(points)`` method taking (N, 2) DP points.
    sdp_obj : SquareDalitzPlot2
        SDP helper providing the DP <-> SDP maps and ``jacobian``.
    idx : tuple
        Particle indices (i, j, k).
    nx, ny : int
        Number of grid points along m' and θ' (each must be >= 2).

    Returns
    -------
    Gamma_plus, Gamma_minus : float

    Raises
    ------
    ValueError
        If ``nx`` or ``ny`` is smaller than 2.
    """
    if nx < 2 or ny < 2:
        raise ValueError(f"nx and ny must be >= 2 (got nx={nx}, ny={ny})")

    u = np.linspace(0.0, 1.0, nx)
    v = np.linspace(0.0, 1.0, ny)
    # 'xy' indexing: arrays have shape (ny, nx), so row r <-> v[r], column c <-> u[c].
    U, V = np.meshgrid(u, v, indexing='xy')
    pts_sdp = np.column_stack([U.ravel(), V.ravel()])

    # Convert to Dalitz coordinates
    S = sdp_to_dp(pts_sdp, sdp_obj, idx=idx)
    s12 = S[:, 0]
    s13 = S[:, 1]

    # Amplitude at the point and at its s12 <-> s13 mirror
    A12 = dkpp_model.full(np.column_stack([s12, s13]))
    A13s = dkpp_model.full(np.column_stack([s13, s12]))

    A_plus = A12 + A13s   # CP-even combination
    A_minus = A12 - A13s  # CP-odd combination

    # Integration measure: inverse of the SDP Jacobian, zeroed where non-finite
    invJ = np.array([1.0 / sdp_obj.jacobian(a, b, *idx) for a, b in zip(s12, s13)])
    invJ = np.where(np.isfinite(invJ), invJ, 0.0)

    # Trapezoidal weights: half weight on the first/last grid line per axis.
    # A plain sum of nx points with spacing 1/(nx-1) would instead cover an
    # area of (nx/(nx-1)) * (ny/(ny-1)) rather than the unit square.
    w_u = np.full(nx, 1.0 / (nx - 1))
    w_u[[0, -1]] *= 0.5
    w_v = np.full(ny, 1.0 / (ny - 1))
    w_v[[0, -1]] *= 0.5
    weights = np.outer(w_v, w_u).ravel()  # same (row=v, col=u) order as pts_sdp

    def _integrate(amplitude):
        integrand = np.abs(amplitude) ** 2 * invJ
        # NaN amplitudes times a zeroed invJ would still give NaN; drop them.
        integrand = np.where(np.isfinite(integrand), integrand, 0.0)
        return float(np.sum(integrand * weights))

    Gamma_plus = _integrate(A_plus)
    Gamma_minus = _integrate(A_minus)
    return Gamma_plus, Gamma_minus

In [10]:
print("Computing Gamma_plus and Gamma_minus from amplitude model...")
gamma_p, gamma_m = compute_Gamma_pm_from_amplitude(
    dkpp, sdp_obj, idx=(1, 2, 3), nx=500, ny=500
)
# C = (G+ p_even - G- p_odd) / (G+ + G-) depends on G+/- only through their
# ratio (an overall scale cancels), so the ratio is the number to check.
print(f"Gamma_plus  = {gamma_p:.4f}")
print(f"Gamma_minus = {gamma_m:.4f}")
print(f"Ratio Gamma_+/Gamma_- = {gamma_p/gamma_m:.4f}")

Computing Gamma_plus and Gamma_minus from amplitude model...


  dmp_dsij = -(1.0 / np.pi) * (1.0 / np.sqrt(one_minus_x2)) * (1.0 / (denom_m * np.sqrt(mijSq)))
  cosHel = -(mikSq - self.mSq[i] - self.mSq[k] - 2.0 * EiCmsij * EkCmsij) / (2.0 * self.qi * self.qk)


Gamma_plus  = 949.7953
Gamma_minus = 2008.2305
Ratio Gamma_+/Gamma_- = 0.4730


---
## 6. Joint Training Function

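The core training loop, which in each epoch runs two alternating passes with separate optimizer steps (the NLL terms and the penalty are not summed into a single loss per step):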
1. Computes NLL for each flow using its own data (flavor, even, odd)
2. Computes the constraint penalty on **B-decay points** (not random points)
3. Uses precomputed swapped coordinates for abJ calculation

**Key insight**: The constraint should be enforced on points that matter for B→DK analysis,
not on uniformly sampled points. We use actual B-decay MC/data points for constraint evaluation.

In [28]:
def _lambda_at_epoch(lam, epoch, warmup_epochs):
    """Linear warm-up of the penalty weight; warmup_epochs <= 0 disables warm-up."""
    if warmup_epochs <= 0:
        return lam
    return lam * min(epoch / warmup_epochs, 1.0)


def _nll_pass(flows, optimizer, train_loader, device):
    """Take one optimizer step per batch on the summed NLL of the three flows.

    Returns ((sum_flavor, sum_even, sum_odd), n_points). Each sum is the
    batch-mean NLL weighted by the CP-odd batch size.
    """
    flow_flavor, flow_even, flow_odd = flows
    sum_flavor = sum_even = sum_odd = 0.0
    n_points = 0
    # The 4th element (swapped CP-odd coordinates) is not needed for the NLL.
    for x_flavor, x_even, x_odd, _ in train_loader:
        x_flavor = x_flavor.to(device)
        x_even = x_even.to(device)
        x_odd = x_odd.to(device)
        batch = x_odd.size(0)

        # Each flow is scored on its own data
        nll_flavor = -flow_flavor.log_prob(x_flavor).mean()
        nll_even = -flow_even.log_prob(x_even).mean()
        nll_odd = -flow_odd.log_prob(x_odd).mean()

        optimizer.zero_grad()
        (nll_flavor + nll_even + nll_odd).backward()
        optimizer.step()

        sum_flavor += nll_flavor.item() * batch
        sum_even += nll_even.item() * batch
        sum_odd += nll_odd.item() * batch
        n_points += batch
    return (sum_flavor, sum_even, sum_odd), n_points


def _constraint_pass(flows, optimizer, points, points_swapped, Gamma_plus, Gamma_minus,
                     lam_value, batch_size, penalty_power):
    """Take one optimizer step per chunk of B-decay points on lam_value * penalty.

    Returns (penalty_sum, n_violations). penalty_sum is the chunk-mean penalty
    weighted by chunk size; n_violations counts points with |C| > abJ.
    """
    flow_flavor, flow_even, flow_odd = flows
    denom = Gamma_minus + Gamma_plus
    n_points = len(points)
    penalty_sum = 0.0
    n_violations = 0
    for start in range(0, n_points, batch_size):
        stop = min(start + batch_size, n_points)
        x_B = points[start:stop]
        x_B_swapped = points_swapped[start:stop]
        chunk = len(x_B)

        # Densities of the CP-even / CP-odd flows at the B-decay points
        p_even = torch.exp(flow_even.log_prob(x_B))
        p_odd = torch.exp(flow_odd.log_prob(x_B))

        # abJ = sqrt(p_flavor(x)) * sqrt(p_flavor(x_swapped)), formed in log space
        log_p_flavor_orig = flow_flavor.log_prob(x_B)
        log_p_flavor_swap = flow_flavor.log_prob(x_B_swapped)
        abJ = torch.exp(0.5 * (log_p_flavor_orig + log_p_flavor_swap))

        # C = (G+ * p_even - G- * p_odd) / (G+ + G-)
        C = (Gamma_plus * p_even - Gamma_minus * p_odd) / denom

        # Hinge on |C| <= abJ, raised to penalty_power
        violation = torch.clamp(torch.abs(C) - abJ, min=0.0)
        penalty = (violation ** penalty_power).mean()

        optimizer.zero_grad()
        (lam_value * penalty).backward()
        optimizer.step()

        penalty_sum += penalty.item() * chunk
        n_violations += (violation > 0).sum().item()
    return penalty_sum, n_violations


def train_joint_flows_with_B_constraint(
    flow_flavor,
    flow_even,
    flow_odd,
    train_loader,
    constraint_points,
    constraint_points_swapped,
    Gamma_plus=1.0,
    Gamma_minus=1.0,
    lam=10.0,
    warmup_epochs=20,
    num_epochs=200,
    lr_odd=1e-3,
    lr_even=1e-4,
    lr_flavor=1e-4,
    constraint_batch_size=10000,
    device="cuda" if torch.cuda.is_available() else "cpu",
    penalty_power=1,
):
    """
    Train the three flows on their NLLs plus a |C| <= abJ penalty on B-decay points.

    Each epoch runs two alternating passes that share one Adam optimizer:

    1. NLL pass: for every batch of ``train_loader``, one optimizer step on
       ``NLL_flavor + NLL_even + NLL_odd``.
    2. Constraint pass: ``constraint_points`` are processed in consecutive
       chunks of ``constraint_batch_size`` (same order every epoch). Each chunk
       gets one optimizer step on
       ``lam_epoch * mean(max(|C| - abJ, 0) ** penalty_power)``, where

           C   = (G+ p_even(x) - G- p_odd(x)) / (G+ + G-)
           abJ = sqrt(p_flavor(x)) * sqrt(p_flavor(x_swapped))

    The NLL terms and the penalty are therefore not summed into one loss per
    step. The reported ``total_loss`` is the epoch-average NLLs plus
    ``lam_epoch * avg_penalty``, and it is also the metric fed to the
    ReduceLROnPlateau scheduler.

    Parameters
    ----------
    flow_flavor, flow_even, flow_odd : Flow
        The three normalizing flow models. They are trained in place.
    train_loader : DataLoader
        Must yield (x_flavor, x_even, x_odd, x_odd_swapped) tuples. Only the
        first three elements are used.
    constraint_points : torch.Tensor, shape (N, 2)
        B-decay points (SDP coordinates) where the constraint is enforced.
    constraint_points_swapped : torch.Tensor, shape (N, 2)
        Swapped SDP coordinates, row-aligned with ``constraint_points``.
    Gamma_plus, Gamma_minus : float
        CP-even/odd decay widths.
    lam : float
        Maximum penalty strength.
    warmup_epochs : int
        lam_epoch = lam * min(epoch / warmup_epochs, 1) for epoch = 1, 2, ...
        A value <= 0 disables the warm-up (lam_epoch = lam from epoch 1).
    num_epochs : int
        Total training epochs.
    lr_odd, lr_even, lr_flavor : float
        Initial learning rates of the per-flow parameter groups.
    constraint_batch_size : int
        Chunk size for the constraint pass (bounds memory).
    device : str
        Torch device.
    penalty_power : int or float
        Exponent applied to the hinge violation. The default 1 is a linear
        penalty; 2 gives the squared hinge.

    Returns
    -------
    flows : tuple (flow_flavor, flow_even, flow_odd)
    history : dict
        Per-epoch metrics plus 'epochs_trained'.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches in an epoch.
    """
    flows = (flow_flavor, flow_even, flow_odd)
    for flow in flows:
        flow.to(device)

    # Move constraint points to device
    constraint_points = constraint_points.to(device)
    constraint_points_swapped = constraint_points_swapped.to(device)
    n_constraint = len(constraint_points)

    # One optimizer with one param group per flow, so each flow keeps its own
    # LR. Group order (odd, even, flavor) fixes the param_groups indices below.
    optimizer = torch.optim.Adam([
        {'params': flow_odd.parameters(), 'lr': lr_odd},
        {'params': flow_even.parameters(), 'lr': lr_even},
        {'params': flow_flavor.parameters(), 'lr': lr_flavor},
    ])

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.65, patience=5, min_lr=1e-6
    )

    history = {
        'total_loss': [],
        'nll_flavor': [],
        'nll_even': [],
        'nll_odd': [],
        'penalty_loss': [],
        'violation_frac': [],
        'learning_rate_odd': [],
        'learning_rate_even': [],
        'learning_rate_flavor': [],
        'lambda': [],
        'epochs_trained': num_epochs
    }

    for epoch in tqdm(range(1, num_epochs + 1), desc="Training", ncols=80):
        for flow in flows:
            flow.train()

        current_lam = _lambda_at_epoch(lam, epoch, warmup_epochs)

        # ===== Part 1: NLL training on flavor/even/odd data =====
        nll_sums, total_pts = _nll_pass(flows, optimizer, train_loader, device)
        if total_pts == 0:
            raise ValueError("train_loader yielded no batches; cannot compute NLL averages")

        # ===== Part 2: Constraint penalty on B-decay points =====
        total_pen, total_viol = _constraint_pass(
            flows, optimizer, constraint_points, constraint_points_swapped,
            Gamma_plus, Gamma_minus, current_lam, constraint_batch_size, penalty_power,
        )

        # Epoch averages
        avg_nll_flavor, avg_nll_even, avg_nll_odd = (s / total_pts for s in nll_sums)
        if n_constraint > 0:
            avg_pen = total_pen / n_constraint
            viol_frac = total_viol / n_constraint
        else:
            avg_pen = 0.0
            viol_frac = 0.0

        total_loss = avg_nll_flavor + avg_nll_even + avg_nll_odd + current_lam * avg_pen

        history['total_loss'].append(total_loss)
        history['nll_flavor'].append(avg_nll_flavor)
        history['nll_even'].append(avg_nll_even)
        history['nll_odd'].append(avg_nll_odd)
        history['penalty_loss'].append(avg_pen)
        history['violation_frac'].append(viol_frac)
        history['learning_rate_odd'].append(optimizer.param_groups[0]['lr'])
        history['learning_rate_even'].append(optimizer.param_groups[1]['lr'])
        history['learning_rate_flavor'].append(optimizer.param_groups[2]['lr'])
        history['lambda'].append(current_lam)

        scheduler.step(total_loss)

        if epoch % 5 == 0 or epoch == 1:
            print(f"[{epoch:03d}] NLL: flav={avg_nll_flavor:.4f} even={avg_nll_even:.4f} "
                  f"odd={avg_nll_odd:.4f} | Pen={avg_pen:.2e} Viol={viol_frac:.4f} λ={current_lam:.1f}")

    return (flow_flavor, flow_even, flow_odd), history

---
## 7. Load Data and Precompute Swapped Coordinates

**Choose ONE of the following cells to run:**
- **Cell 7a (Test)**: Uses small test datasets (100k events) for quick testing
- **Cell 7b (Production)**: Loads the full datasets, trains on a random subsample of up to 2M events per sample, and starts from pretrained weights

In [22]:
# =============================================================================
# 7a. TEST CONFIGURATION - Small datasets for quick testing
# =============================================================================
# Run this cell for testing on local machine without full data

# Input files in SDP coordinates, ~100k events each
data_flavor_path = "D_Kspipi_SDP_test.npy"       # flavor-tagged events
data_even_path = "D_Kspipi_even_SDP_test.npy"    # CP-even events
data_odd_path = "D_Kspipi_odd_SDP_test.npy"      # CP-odd events

print("=== TEST CONFIGURATION ===")
print("Loading test datasets (100k events each)...")
data_flavor, data_even, data_odd = (
    np.load(data_flavor_path), np.load(data_even_path), np.load(data_odd_path)
)

# The test samples are small enough to train on in full (no subsampling).
train_flavor, train_even, train_odd = data_flavor, data_even, data_odd

print(f"  Flavor: {len(train_flavor):,} events")
print(f"  Even:   {len(train_even):,} events")
print(f"  Odd:    {len(train_odd):,} events")

# Test-run settings consumed by later cells
USE_PRETRAINED = False  # train from scratch
BATCH_SIZE = 10000
FLOW_CONFIG = {'num_flows': 16, 'hidden_features': 64, 'num_bins': 16}
TRAIN_CONFIG_OVERRIDE = {'num_epochs': 100, 'warmup_epochs': 10, 'lr_odd': 1e-3, 'lr_even': 1e-3, 'lr_flavor': 1e-3}
N_TEST_POINTS = 50000

=== TEST CONFIGURATION ===
Loading test datasets (100k events each)...
  Flavor: 100,000 events
  Even:   100,000 events
  Odd:    100,000 events


In [None]:
# =============================================================================
# 7b. PRODUCTION CONFIGURATION - Full datasets with pretrained weights
# =============================================================================
# Run this cell on the other computer with full data and pretrained models
# Skip cell 7a above if running this one

data_flavor_path = "D_Kspipi_SDP_2e7.npy"        # ~20M flavor events
data_even_path = "D_Kspipi_even_SDP_2e7.npy"     # ~20M CP-even events
data_odd_path = "D_Kspipi_odd_SDP_2e7.npy"       # ~20M CP-odd events

print("=== PRODUCTION CONFIGURATION ===")
print("Loading full datasets...")
# Memory-map the ~20M-row files instead of reading them into RAM. Only the
# randomly selected training rows are copied into memory below (fancy indexing
# a memmap returns an in-memory array).
data_flavor = np.load(data_flavor_path, mmap_mode="r")
data_even = np.load(data_even_path, mmap_mode="r")
data_odd = np.load(data_odd_path, mmap_mode="r")

print(f"  Flavor: {len(data_flavor):,} events")
print(f"  Even:   {len(data_even):,} events")
print(f"  Odd:    {len(data_odd):,} events")

# Subsample for training (fixed seed for reproducibility)
train_size = 2_000_000
np.random.seed(42)
idx_flavor = np.random.choice(len(data_flavor), size=min(train_size, len(data_flavor)), replace=False)
idx_even = np.random.choice(len(data_even), size=min(train_size, len(data_even)), replace=False)
idx_odd = np.random.choice(len(data_odd), size=min(train_size, len(data_odd)), replace=False)

train_flavor = data_flavor[idx_flavor]
train_even = data_even[idx_even]
train_odd = data_odd[idx_odd]

print(f"\nTraining with:")
print(f"  Flavor: {len(train_flavor):,} events")
print(f"  Even:   {len(train_even):,} events")
print(f"  Odd:    {len(train_odd):,} events")

# Production configuration flags
USE_PRETRAINED = True
BATCH_SIZE = 50000
FLOW_CONFIG = {'num_flows': 12, 'hidden_features': 128, 'num_bins': 24}
TRAIN_CONFIG_OVERRIDE = {'num_epochs': 100, 'warmup_epochs': 10, 'lr_odd': 1e-3, 'lr_even': 1e-4, 'lr_flavor': 1e-4}
N_TEST_POINTS = 200000

# Pretrained model paths (adjust if needed). The checkpoints must match FLOW_CONFIG.
PRETRAINED_FLAVOR = 'test_ensemble_2e6/trial_seed1.pth'
PRETRAINED_EVEN = 'test_ensemble_even_2e6/trial_seed1.pth'
PRETRAINED_ODD = 'test_ensemble_odd_2e6/trial_seed1.pth'

In [13]:
# The configuration cell (7a or 7b) already defines train_flavor / train_even /
# train_odd: 7a uses every test event, and 7b uses a random subsample of at most
# 2,000,000 events per sample. Reassigning them to data_* here would silently
# discard 7b's subsample. Only fall back to the full arrays when no
# configuration cell has set them.
if not all(name in globals() for name in ("train_flavor", "train_even", "train_odd")):
    train_flavor = data_flavor
    train_even = data_even
    train_odd = data_odd

print("Training with:")
print(f"  Flavor: {len(train_flavor):,} events")
print(f"  Even:   {len(train_even):,} events")
print(f"  Odd:    {len(train_odd):,} events")

Training with:
  Flavor: 100,000 events
  Even:   100,000 events
  Odd:    100,000 events


In [14]:
# Swapped SDP coordinates for the CP-odd training sample. JointDalitzDataset
# takes them as its 4th input, but the training loop in Section 6 unpacks that
# element as `_` and never uses it. The constraint uses
# constraint_points_swapped instead.
print("Precomputing swapped SDP coordinates for NLL training data...")

train_odd_swapped = compute_swapped_sdp_coords(
    train_odd,
    sdp_obj,
    idx=(1, 2, 3),
    pair_swap=(1, 3, 2),
)

print(f"Done. Shape: {train_odd_swapped.shape}")

Precomputing swapped SDP coordinates for NLL training data...
Done. Shape: (100000, 2)


---
## 7c. Load B-decay Data

In [16]:
# =============================================================================
# Load B-decay data for constraint evaluation
# =============================================================================
# The npz file structure (from Gamma-fit-pipeline):
#   - 'dataP_sdp': B+ experimental data (~50k), shape (N, 2) in SDP coordinates
#   - 'dataM_sdp': B- experimental data (~50k), shape (N, 2) in SDP coordinates
#   - 'mcM_sdp': B- Monte Carlo (~500k), shape (N, 2) in SDP coordinates
#
# We only need one MC sample (mcM_sdp) since it covers the same phase space

B_DATA_PATH = "BpBm_samples.npz"

print(f"Loading B-decay data from {B_DATA_PATH}...")
# np.load on an .npz returns an NpzFile that keeps the archive open. Use it as
# a context manager so the file handle is released. Indexing by key reads each
# array into memory, so the arrays remain valid after the archive is closed.
with np.load(B_DATA_PATH) as B_data:
    # Load using the keys from Gamma-fit-pipeline
    Bp_sdp = B_data['dataP_sdp']  # B+ experimental data
    Bm_sdp = B_data['dataM_sdp']  # B- experimental data
    MC_sdp = B_data['mcM_sdp']    # Use B- MC only (mcP_sdp covers same phase space)

print(f"\nLoaded B-decay data:")
print(f"  B+ data (dataP_sdp):  {len(Bp_sdp):,} events")
print(f"  B- data (dataM_sdp):  {len(Bm_sdp):,} events")
print(f"  MC (mcM_sdp):         {len(MC_sdp):,} events")

Loading B-decay data from BpBm_samples.npz...

Loaded B-decay data:
  B+ data (dataP_sdp):  50,000 events
  B- data (dataM_sdp):  50,000 events
  MC (mcM_sdp):         500,000 events


In [17]:
# =============================================================================
# Choose which B-decay points to use for constraint evaluation
# =============================================================================
# Options:
#   1. "all"       - Bp + Bm + MC (most comprehensive)
#   2. "data_only" - Bp + Bm (focuses on actual measurements)
#   3. "mc_only"   - MC only (if data is blinded or not available)

CONSTRAINT_MODE = "all"  # Options: "all", "data_only", "mc_only"

if CONSTRAINT_MODE not in ("all", "data_only", "mc_only"):
    raise ValueError(f"Unknown CONSTRAINT_MODE: {CONSTRAINT_MODE}")

if CONSTRAINT_MODE == "mc_only":
    constraint_points_np = MC_sdp
    print(f"Using MC ONLY for constraint: {len(constraint_points_np):,}")
elif CONSTRAINT_MODE == "data_only":
    constraint_points_np = np.vstack([Bp_sdp, Bm_sdp])
    print(f"Using DATA ONLY (Bp + Bm) for constraint: {len(constraint_points_np):,}")
else:  # "all"
    constraint_points_np = np.vstack([Bp_sdp, Bm_sdp, MC_sdp])
    print(f"Using ALL B-decay points for constraint: {len(constraint_points_np):,}")

# Swapped coordinates are needed for abJ; one Python-level call per point.
print("\nPrecomputing swapped SDP coordinates for constraint points...")
print("(This may take a few minutes for large datasets)")

constraint_points_swapped_np = compute_swapped_sdp_coords(
    constraint_points_np, sdp_obj, idx=(1, 2, 3), pair_swap=(1, 3, 2)
)

# float32 tensors to match the flows' parameter dtype
constraint_points, constraint_points_swapped = (
    torch.from_numpy(arr.astype(np.float32))
    for arr in (constraint_points_np, constraint_points_swapped_np)
)

print(f"Done. Constraint points shape: {constraint_points.shape}")

Using ALL B-decay points for constraint: 600,000

Precomputing swapped SDP coordinates for constraint points...
(This may take a few minutes for large datasets)
Done. Constraint points shape: torch.Size([600000, 2])


In [18]:
# Joint dataset + loader; BATCH_SIZE comes from configuration cell 7a or 7b.
joint_dataset = JointDalitzDataset(train_flavor, train_even, train_odd, train_odd_swapped)
joint_loader = DataLoader(
    dataset=joint_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=device == "cuda",  # pinned host memory only helps CUDA transfers
)

print(f"Batch size: {BATCH_SIZE}")
print(f"Batches per epoch: {len(joint_loader)}")

Batch size: 10000
Batches per epoch: 10


---
## 8. Initialize Flows from Pretrained Weights

When `USE_PRETRAINED` is True (production configuration 7b), we start from pretrained flows to:
1. Speed up convergence
2. Preserve good individual NLL performance
3. Focus training on satisfying the constraint

In [19]:
# Three flows with identical architecture (FLOW_CONFIG from cell 7a or 7b).
flow_flavor, flow_even, flow_odd = (create_flow(**FLOW_CONFIG) for _ in range(3))

n_params = sum(p.numel() for p in flow_flavor.parameters())
print(f"Flow config: {FLOW_CONFIG}")
print(f"Each flow has {n_params:,} parameters")
print(f"Total parameters: {3 * n_params:,}")

if not USE_PRETRAINED:
    print("\nTraining from scratch (no pretrained weights).")
else:
    # Checkpoints must match FLOW_CONFIG, or load_state_dict will fail.
    print(f"\nLoading pretrained weights...")
    flow_flavor.load_state_dict(torch.load(PRETRAINED_FLAVOR, map_location=device))
    flow_even.load_state_dict(torch.load(PRETRAINED_EVEN, map_location=device))
    flow_odd.load_state_dict(torch.load(PRETRAINED_ODD, map_location=device))
    print(f"  Flavor: {PRETRAINED_FLAVOR}")
    print(f"  Even:   {PRETRAINED_EVEN}")
    print(f"  Odd:    {PRETRAINED_ODD}")

Flow config: {'num_flows': 8, 'hidden_features': 64, 'num_bins': 16}
Each flow has 58,744 parameters
Total parameters: 176,232

Training from scratch (no pretrained weights).


---
## 9. Train Joint Model

In [24]:
# Base training settings; TRAIN_CONFIG_OVERRIDE (from cell 7a or 7b) wins on
# any key it also defines.
train_config = dict(
    Gamma_plus=gamma_p,
    Gamma_minus=gamma_m,
    lam=10.0,
    device=device,
)
train_config.update(TRAIN_CONFIG_OVERRIDE)

print("Training configuration:")
for k, v in train_config.items():
    print(f"  {k}: {v}")

Training configuration:
  Gamma_plus: 949.7952740443438
  Gamma_minus: 2008.2304765864694
  lam: 10.0
  device: cpu
  num_epochs: 100
  warmup_epochs: 10
  lr_odd: 0.001
  lr_even: 0.001
  lr_flavor: 0.001


In [29]:
# Run training with B-decay constraint points.
# The flows are optimised in place; the returned tuple holds the same objects.
(flow_flavor_trained, flow_even_trained, flow_odd_trained), history = train_joint_flows_with_B_constraint(
    flow_flavor=flow_flavor,
    flow_even=flow_even,
    flow_odd=flow_odd,
    train_loader=joint_loader,
    constraint_points=constraint_points,
    constraint_points_swapped=constraint_points_swapped,
    constraint_batch_size=50000,  # chunk size for the constraint pass
    **train_config,
)

Training:   1%|▎                              | 1/100 [00:53<1:28:26, 53.60s/it]

[001] NLL: flav=-1.0827 even=-0.9246 odd=-0.8772 | Pen=3.12e-02 Viol=0.1751 λ=1.0


Training:   5%|█▌                             | 5/100 [04:08<1:17:15, 48.79s/it]

[005] NLL: flav=-1.0784 even=-0.9134 odd=-0.8882 | Pen=7.28e-03 Viol=0.1025 λ=5.0


Training:  10%|███                           | 10/100 [08:24<1:16:32, 51.03s/it]

[010] NLL: flav=-1.0708 even=-0.9332 odd=-0.8906 | Pen=1.61e-03 Viol=0.0470 λ=10.0


Training:  14%|████▏                         | 14/100 [11:43<1:12:04, 50.28s/it]


KeyboardInterrupt: 

In [25]:
# Jointly train the flavor, CP-even and CP-odd flows; the constraint penalty
# is evaluated on the B-decay constraint points, 50k points per batch.
(
    flow_flavor_trained,
    flow_even_trained,
    flow_odd_trained,
), history = train_joint_flows_with_B_constraint(
    flow_flavor, flow_even, flow_odd,
    joint_loader,
    constraint_points, constraint_points_swapped,
    **train_config,
    constraint_batch_size=50000,
)

Training:   1%|▎                              | 1/100 [00:57<1:34:43, 57.41s/it]

[001] NLL: flav=-1.0339 even=-0.8625 odd=-0.8347 | Pen=2.62e-02 Viol=0.2074 λ=1.0


Training:   5%|█▌                             | 5/100 [05:25<1:47:48, 68.09s/it]

[005] NLL: flav=-1.0255 even=-0.8563 odd=-0.8084 | Pen=4.32e-03 Viol=0.1412 λ=5.0


Training:  10%|███                           | 10/100 [10:39<1:38:39, 65.78s/it]

[010] NLL: flav=-0.9197 even=-0.8523 odd=-0.8025 | Pen=8.52e-03 Viol=0.2051 λ=10.0


Training:  15%|████▌                         | 15/100 [15:18<1:22:13, 58.04s/it]

[015] NLL: flav=-1.0491 even=-0.8733 odd=-0.8458 | Pen=3.94e-04 Viol=0.1012 λ=10.0


Training:  20%|██████                        | 20/100 [19:55<1:15:04, 56.30s/it]

[020] NLL: flav=-1.0580 even=-0.8925 odd=-0.8463 | Pen=4.50e-04 Viol=0.1023 λ=10.0


Training:  25%|███████▌                      | 25/100 [24:42<1:12:12, 57.76s/it]

[025] NLL: flav=-1.0611 even=-0.9036 odd=-0.8835 | Pen=2.27e-04 Viol=0.1052 λ=10.0


Training:  30%|█████████                     | 30/100 [29:17<1:04:37, 55.39s/it]

[030] NLL: flav=-1.0638 even=-0.9042 odd=-0.8753 | Pen=3.48e-04 Viol=0.1139 λ=10.0


Training:  35%|███████████▏                    | 35/100 [33:39<56:47, 52.42s/it]

[035] NLL: flav=-1.0757 even=-0.9040 odd=-0.8802 | Pen=2.72e-04 Viol=0.1110 λ=10.0


Training:  40%|████████████▊                   | 40/100 [38:10<53:59, 53.99s/it]

[040] NLL: flav=-1.0202 even=-0.9028 odd=-0.8281 | Pen=4.42e-04 Viol=0.1490 λ=10.0


Training:  45%|██████████████▍                 | 45/100 [42:56<51:07, 55.77s/it]

[045] NLL: flav=-1.0686 even=-0.9213 odd=-0.8872 | Pen=8.49e-04 Viol=0.1450 λ=10.0


Training:  50%|████████████████                | 50/100 [47:31<45:57, 55.15s/it]

[050] NLL: flav=-1.0881 even=-0.9244 odd=-0.8945 | Pen=3.47e-04 Viol=0.1221 λ=10.0


Training:  55%|█████████████████▌              | 55/100 [52:06<41:18, 55.08s/it]

[055] NLL: flav=-1.0942 even=-0.9319 odd=-0.9024 | Pen=2.14e-04 Viol=0.1032 λ=10.0


Training:  60%|███████████████████▏            | 60/100 [56:47<37:08, 55.71s/it]

[060] NLL: flav=-1.0867 even=-0.9327 odd=-0.8979 | Pen=3.76e-04 Viol=0.1214 λ=10.0


Training:  65%|███████████████████▌          | 65/100 [1:01:21<32:10, 55.14s/it]

[065] NLL: flav=-1.0973 even=-0.9369 odd=-0.9064 | Pen=2.14e-04 Viol=0.1073 λ=10.0


Training:  70%|█████████████████████         | 70/100 [1:06:00<27:34, 55.14s/it]

[070] NLL: flav=-1.0980 even=-0.9389 odd=-0.9067 | Pen=1.68e-04 Viol=0.1012 λ=10.0


Training:  75%|██████████████████████▌       | 75/100 [1:10:29<22:38, 54.35s/it]

[075] NLL: flav=-1.0972 even=-0.9385 odd=-0.9070 | Pen=2.31e-04 Viol=0.1058 λ=10.0


Training:  80%|████████████████████████      | 80/100 [1:14:58<17:56, 53.82s/it]

[080] NLL: flav=-1.0972 even=-0.9371 odd=-0.9041 | Pen=1.80e-04 Viol=0.1022 λ=10.0


Training:  85%|█████████████████████████▌    | 85/100 [1:19:27<13:37, 54.52s/it]

[085] NLL: flav=-1.1011 even=-0.9431 odd=-0.9101 | Pen=1.57e-04 Viol=0.1036 λ=10.0


Training:  90%|███████████████████████████   | 90/100 [1:24:03<09:11, 55.20s/it]

[090] NLL: flav=-1.1002 even=-0.9434 odd=-0.9096 | Pen=1.22e-04 Viol=0.0952 λ=10.0


Training:  95%|████████████████████████████▌ | 95/100 [1:28:27<04:26, 53.26s/it]

[095] NLL: flav=-1.1022 even=-0.9450 odd=-0.9115 | Pen=1.61e-04 Viol=0.1033 λ=10.0


Training: 100%|█████████████████████████████| 100/100 [1:32:55<00:00, 55.75s/it]

[100] NLL: flav=-1.0992 even=-0.9397 odd=-0.9096 | Pen=3.34e-04 Viol=0.1033 λ=10.0





In [None]:
# Save trained models, training history and the run configuration.
output_dir = Path("joint_trained_flows")
output_dir.mkdir(parents=True, exist_ok=True)

torch.save(flow_flavor_trained.state_dict(), output_dir / "flow_flavor.pth")
torch.save(flow_even_trained.state_dict(), output_dir / "flow_even.pth")
torch.save(flow_odd_trained.state_dict(), output_dir / "flow_odd.pth")

# Save training history
with open(output_dir / "history.json", 'w') as f:
    json.dump(history, f, indent=2)

# Save config. The flows were built from FLOW_CONFIG (there is no global
# `flow_config` for this run). train_config may hold values json cannot
# encode natively (e.g. a torch.device), so those are deliberately stored
# as their string form.
config = {
    **train_config,
    **FLOW_CONFIG,
    'timestamp': datetime.now().isoformat()
}
with open(output_dir / "config.json", 'w') as f:
    json.dump(config, f, indent=2, default=str)

print(f"Models saved to {output_dir}/")

---
## 10. Plot Training History

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# Panel (0, 0): per-flow negative log-likelihood; curve order fixes the colours.
ax = axes[0, 0]
for _hist_key, _curve_label in (('nll_flavor', 'flavor'),
                                ('nll_even', 'even'),
                                ('nll_odd', 'odd')):
    ax.plot(history[_hist_key], label=_curve_label, linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('NLL')
ax.set_title('Negative Log-Likelihood')
ax.legend()
ax.grid(True, alpha=0.3)

# Single-curve panels: (axis position, history key, colour, y-label, title, log-y?)
for _pos, _hist_key, _colour, _ylabel, _title, _log_y in (
    ((0, 1), 'total_loss', 'black', 'Loss', 'Total Loss', False),
    ((0, 2), 'penalty_loss', 'firebrick', 'Penalty', 'Constraint Penalty', True),
    ((1, 0), 'violation_frac', 'darkorange', 'Fraction', 'Violation Fraction (|C| > abJ)', False),
    ((1, 1), 'lambda', 'seagreen', r'$\lambda$', r'$\lambda$ Schedule', False),
):
    ax = axes[_pos]
    ax.plot(history[_hist_key], linewidth=2, color=_colour)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(_ylabel)
    ax.set_title(_title)
    if _log_y:
        ax.set_yscale('log')
    ax.grid(True, alpha=0.3)

# Panel (1, 2): learning rate of each optimiser, on a log scale.
ax = axes[1, 2]
for _hist_key, _curve_label in (('learning_rate_odd', 'odd'),
                                ('learning_rate_even', 'even'),
                                ('learning_rate_flavor', 'flavor')):
    ax.plot(history[_hist_key], label=_curve_label, linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rates')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / 'training_history.pdf', dpi=150, bbox_inches='tight')
plt.show()

# Last-epoch summary
print(f"\nFinal NLLs:")
print(f"  Flavor: {history['nll_flavor'][-1]:.6f}")
print(f"  Even:   {history['nll_even'][-1]:.6f}")
print(f"  Odd:    {history['nll_odd'][-1]:.6f}")
print(f"\nFinal violation fraction: {history['violation_frac'][-1]:.6f}")

---
## 11. Evaluate Constraint Violations

In [None]:
def check_violation_rate_on_points(flow_flavor, flow_even, flow_odd,
                                    points, points_swapped,
                                    Gamma_plus, Gamma_minus,
                                    device='cpu', label="", batch_size=None):
    """
    Check what fraction of given points violate |C| > abJ.

    With p_x = exp(flow_x.log_prob(point)):
        C   = (Gamma_plus * p_even - Gamma_minus * p_odd) / (Gamma_plus + Gamma_minus)
        abJ = sqrt(p_flavor(point) * p_flavor(point_swapped))

    The three flows are switched to eval mode and evaluated under
    ``torch.no_grad()``. A summary is printed.

    Parameters
    ----------
    flow_flavor, flow_even, flow_odd : objects with ``eval()`` and ``log_prob(x)``
        Trained flows (e.g. nflows ``Flow``).
    points : ndarray, Tensor or array-like, shape (N, 2)
        SDP coordinates to check
    points_swapped : ndarray, Tensor or array-like, shape (N, 2)
        Swapped SDP coordinates (row i corresponds to points[i]). May be a
        different container type than ``points``.
    Gamma_plus, Gamma_minus : float
        Weights of the CP-even / CP-odd densities in C.
    device : str or torch.device
        Device on which the flows are evaluated.
    label : str
        Name printed in the summary.
    batch_size : int, optional
        Evaluate the flows on chunks of this many points to bound memory.
        ``None`` (default) evaluates all points in a single pass.

    Returns
    -------
    C, abJ, violated, ratio : ndarray, shape (N,)
        C, the |A||Abar| bound, the boolean violation mask, and
        |C| / max(abJ, 1e-10).

    Raises
    ------
    ValueError
        If ``points`` and ``points_swapped`` have different lengths, or
        ``batch_size`` is not positive.
    """
    flow_flavor.eval()
    flow_even.eval()
    flow_odd.eval()

    def _as_float_tensor(x):
        # Convert each argument independently so a mixed ndarray/Tensor pair
        # works; cast to float32 to match the flows' parameters.
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x.astype(np.float32)).to(device)
        return torch.as_tensor(x, dtype=torch.float32, device=device)

    pts_t = _as_float_tensor(points)
    pts_swapped_t = _as_float_tensor(points_swapped)

    n_test = len(pts_t)
    if len(pts_swapped_t) != n_test:
        raise ValueError(
            f"points and points_swapped must have the same length "
            f"({n_test} vs {len(pts_swapped_t)})"
        )

    if batch_size is None:
        step = max(n_test, 1)  # one pass; max() keeps range() valid for N == 0
    else:
        step = int(batch_size)
        if step <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

    p_even_chunks, p_odd_chunks, abJ_chunks = [], [], []
    with torch.no_grad():
        for start in range(0, n_test, step):
            chunk = pts_t[start:start + step]
            chunk_swapped = pts_swapped_t[start:start + step]

            p_even_chunks.append(torch.exp(flow_even.log_prob(chunk)).cpu().numpy())
            p_odd_chunks.append(torch.exp(flow_odd.log_prob(chunk)).cpu().numpy())

            # Geometric mean of the two flavor densities, computed in log space.
            log_p_flavor_orig = flow_flavor.log_prob(chunk)
            log_p_flavor_swap = flow_flavor.log_prob(chunk_swapped)
            abJ_chunks.append(
                torch.exp(0.5 * (log_p_flavor_orig + log_p_flavor_swap)).cpu().numpy()
            )

    def _join(chunks):
        return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)

    p_even = _join(p_even_chunks)
    p_odd = _join(p_odd_chunks)
    abJ = _join(abJ_chunks)

    denom = Gamma_plus + Gamma_minus
    C = (Gamma_plus * p_even - Gamma_minus * p_odd) / denom

    violated = np.abs(C) > abJ
    # Floor abJ to avoid division by zero where the flavor density vanishes.
    ratio = np.abs(C) / np.maximum(abJ, 1e-10)

    print(f"  {label} ({n_test:,} points):")
    if n_test == 0:
        # Rates and maxima are undefined for an empty set.
        print("    No points to check.")
        return C, abJ, violated, ratio

    print(f"    Violation rate: {violated.mean():.4f} ({violated.sum():,}/{n_test:,})")
    print(f"    Max |C|/abJ: {ratio.max():.4f}")
    if violated.any():
        print(f"    Mean |C|/abJ (violated only): {ratio[violated].mean():.4f}")
    else:
        print(f"    No violations!")

    return C, abJ, violated, ratio

In [None]:
# Evaluate constraint violations of the jointly trained flows on the B-decay points
print("=== CONSTRAINT VIOLATIONS ON B-DECAY POINTS ===\n")

# Swapped SDP coordinates for each subset, computed in the order B+, B-, MC
Bp_swapped, Bm_swapped, MC_swapped = (
    compute_swapped_sdp_coords(sdp_pts, sdp_obj, idx=(1,2,3), pair_swap=(1,3,2))
    for sdp_pts in (Bp_sdp, Bm_sdp, MC_sdp)
)

# The same trained flow triplet is checked against every subset
_trained_flows = (flow_flavor_trained, flow_even_trained, flow_odd_trained)

C_Bp, abJ_Bp, viol_Bp, ratio_Bp = check_violation_rate_on_points(
    *_trained_flows, Bp_sdp, Bp_swapped, gamma_p, gamma_m,
    device=device, label="B+ data",
)
C_Bm, abJ_Bm, viol_Bm, ratio_Bm = check_violation_rate_on_points(
    *_trained_flows, Bm_sdp, Bm_swapped, gamma_p, gamma_m,
    device=device, label="B- data",
)
C_MC, abJ_MC, viol_MC, ratio_MC = check_violation_rate_on_points(
    *_trained_flows, MC_sdp, MC_swapped, gamma_p, gamma_m,
    device=device, label="MC",
)

# Pooled rate over all points (every point weighted equally)
all_viol = np.concatenate((viol_Bp, viol_Bm, viol_MC))
print(f"\n  COMBINED: {all_viol.mean():.4f} violation rate")

In [None]:
# Visualize violations on B+ and B- data (plus the first 50k MC points)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, pts, viol, ratio, title in [
    (axes[0], Bp_sdp, viol_Bp, ratio_Bp, "B+ data"),
    (axes[1], Bm_sdp, viol_Bm, ratio_Bm, "B- data"),
    (axes[2], MC_sdp[:50000], viol_MC[:50000], ratio_MC[:50000], "MC (first 50k)")  # Subsample MC for plotting
]:
    # Background: all points in light gray
    ax.scatter(pts[:, 0], pts[:, 1], c='lightgray', s=1, alpha=0.3, rasterized=True)

    # Violated points colored by severity |C|/abJ, clipped at 3 for contrast.
    if viol.any():
        # ratio uses a 1e-10 floor on abJ, so a violated point with tiny |C|
        # can have ratio < 1; clamp vmax >= vmin so the colour normalisation
        # does not reject the range.
        sc = ax.scatter(pts[viol, 0], pts[viol, 1],
                        c=ratio[viol], cmap='hot', s=4, vmin=1.0,
                        vmax=max(1.0, min(ratio[viol].max(), 3.0)), rasterized=True)
        plt.colorbar(sc, ax=ax, label=r'$|C|/|A||\bar{A}|$', shrink=0.8)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel(r"$m'$")
    ax.set_ylabel(r"$\theta'$")
    ax.set_title(f"{title} - Violation: {viol.mean():.2%}")

plt.tight_layout()
plt.savefig(output_dir / 'violation_map_B_decay.pdf', dpi=150, bbox_inches='tight')
plt.show()

---
## 13. Ensemble Training with Multiple Independent Datasets

This section generates multiple independent training datasets and runs ensemble training trials.
Each trial uses a fresh dataset to capture statistical variations in the trained flows.

### Workflow:
1. **Generate datasets**: Create N sets of 2M events each for flavor, even, and odd
2. **Train ensembles**: For each dataset, train all three flows jointly
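3. **Save results**: Each trial's models and training history are saved as `trial_<i>` files under `ensemble_flavor/`, `ensemble_even/`, `ensemble_odd/` and `histories/`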

In [None]:
# =============================================================================
# 13a. Data Generation Functions
# =============================================================================
from DKpp import DKpp, DKppCorrelated, AmpSample

def _generate_sdp_dataset(model_factory, n_events, sdp_obj, seed=None, nbatch=50000):
    """
    Sample events from an amplitude model and convert them to SDP coordinates.

    Shared implementation of generate_flavor_dataset, generate_cp_even_dataset
    and generate_cp_odd_dataset.

    Parameters
    ----------
    model_factory : callable
        Zero-argument callable returning the amplitude model passed to
        ``AmpSample``. It is called *after* seeding, preserving the original
        seed -> model -> sampler order.
    n_events : int
        Number of events to generate
    sdp_obj : SquareDalitzPlot2
        SDP object for coordinate transformation
    seed : int, optional
        Seed for NumPy's global RNG (``np.random.seed``).
        NOTE(review): reproducibility relies on AmpSample drawing from
        np.random; confirm.
    nbatch : int
        Batch size forwarded to ``AmpSample.generate``.

    Returns
    -------
    ndarray, shape (n_events, 2)
        Events in SDP coordinates (m', θ')
    """
    if seed is not None:
        np.random.seed(seed)

    sampler = AmpSample(model_factory())
    points_dp = sampler.generate(n_events, nbatch=nbatch)
    return dp_to_sdp(points_dp, sdp_obj, idx=(1,2,3))


def generate_flavor_dataset(n_events, sdp_obj, seed=None, nbatch=50000):
    """
    Generate D → Ksππ flavor-tagged events in SDP coordinates.

    Parameters
    ----------
    n_events : int
        Number of events to generate
    sdp_obj : SquareDalitzPlot2
        SDP object for coordinate transformation
    seed : int, optional
        Random seed for reproducibility
    nbatch : int
        Sampler batch size (default 50000)

    Returns
    -------
    ndarray, shape (n_events, 2)
        Events in SDP coordinates (m', θ')
    """
    return _generate_sdp_dataset(DKpp, n_events, sdp_obj, seed=seed, nbatch=nbatch)


def generate_cp_even_dataset(n_events, sdp_obj, seed=None, nbatch=50000):
    """
    Generate CP-even combination events in SDP coordinates.

    Parameters
    ----------
    n_events : int
        Number of events to generate
    sdp_obj : SquareDalitzPlot2
        SDP object for coordinate transformation
    seed : int, optional
        Random seed for reproducibility
    nbatch : int
        Sampler batch size (default 50000)

    Returns
    -------
    ndarray, shape (n_events, 2)
        Events in SDP coordinates (m', θ')
    """
    return _generate_sdp_dataset(
        lambda: DKppCorrelated(cp=+1), n_events, sdp_obj, seed=seed, nbatch=nbatch
    )


def generate_cp_odd_dataset(n_events, sdp_obj, seed=None, nbatch=50000):
    """
    Generate CP-odd combination events in SDP coordinates.

    Parameters
    ----------
    n_events : int
        Number of events to generate
    sdp_obj : SquareDalitzPlot2
        SDP object for coordinate transformation
    seed : int, optional
        Random seed for reproducibility
    nbatch : int
        Sampler batch size (default 50000)

    Returns
    -------
    ndarray, shape (n_events, 2)
        Events in SDP coordinates (m', θ')
    """
    return _generate_sdp_dataset(
        lambda: DKppCorrelated(cp=-1), n_events, sdp_obj, seed=seed, nbatch=nbatch
    )


def generate_all_datasets(n_sets, n_events_per_set, sdp_obj, output_dir, base_seed=42):
    """
    Generate multiple independent datasets for ensemble training.

    Creates directories:
        output_dir/
            flavor/
                set_0.npy, set_1.npy, ...
            even/
                set_0.npy, set_1.npy, ...
            odd/
                set_0.npy, set_1.npy, ...

    Seeding: set ``i`` uses ``seed = base_seed + 1000 * i``; the flavor,
    CP-even and CP-odd samples of that set use ``seed``, ``seed + 100`` and
    ``seed + 200`` respectively. The 1000 stride keeps seeds of different
    sets from colliding.

    Parameters
    ----------
    n_sets : int
        Number of independent datasets to generate
    n_events_per_set : int
        Number of events requested for each dataset and each type
    sdp_obj : SquareDalitzPlot2
        SDP object for coordinate transformation
    output_dir : str or Path
        Base directory to save datasets
    base_seed : int
        Base random seed (see "Seeding" above)

    Returns
    -------
    Path
        ``output_dir`` as a Path.
    """
    output_dir = Path(output_dir)

    # (subdirectory, label for progress messages, generator, seed offset)
    dataset_specs = (
        ("flavor", "flavor", generate_flavor_dataset, 0),
        ("even", "CP-even", generate_cp_even_dataset, 100),
        ("odd", "CP-odd", generate_cp_odd_dataset, 200),
    )

    for subdir, _, _, _ in dataset_specs:
        (output_dir / subdir).mkdir(parents=True, exist_ok=True)

    print(f"Generating {n_sets} independent datasets with {n_events_per_set:,} events each")
    print(f"Output directory: {output_dir}")
    print("=" * 60)

    total_events = 0
    for i in range(n_sets):
        seed = base_seed + i * 1000  # Different seed for each set
        print(f"\n[Set {i+1}/{n_sets}] seed={seed}")

        for subdir, label, generate, seed_offset in dataset_specs:
            print(f"  Generating {label} data...", end=" ", flush=True)
            data = generate(n_events_per_set, sdp_obj, seed=seed + seed_offset)
            np.save(output_dir / subdir / f"set_{i}.npy", data)
            print(f"saved {len(data):,} events")
            # Count what was actually saved rather than what was requested.
            total_events += len(data)

    print("\n" + "=" * 60)
    print(f"Done! Generated {n_sets} sets × 3 types = {n_sets * 3} files")
    print(f"Total events: {total_events:,}")

    return output_dir

In [None]:
# =============================================================================
# 13b. Ensemble Training Function
# =============================================================================

def _load_trial_datasets(data_dir, trial):
    """Load the flavor, even and odd ``set_<trial>.npy`` arrays from ``data_dir``."""
    return tuple(
        np.load(data_dir / kind / f"set_{trial}.npy")
        for kind in ("flavor", "even", "odd")
    )


def _print_ensemble_summary(all_histories):
    """Print mean ± std of the final-epoch NLLs and violation fraction across trials."""
    if not all_histories:
        # np.mean/np.std of an empty list would give nan plus a RuntimeWarning.
        print("\nNo trials were run; no summary statistics.")
        return

    print(f"\nSummary statistics across {len(all_histories)} trials:")
    for prefix, key in (
        ("NLL Flavor: ", 'nll_flavor'),
        ("NLL Even:   ", 'nll_even'),
        ("NLL Odd:    ", 'nll_odd'),
        ("Violation:  ", 'violation_frac'),
    ):
        finals = [h[key][-1] for h in all_histories]
        print(f"  {prefix}{np.mean(finals):.4f} ± {np.std(finals):.4f}")


def run_ensemble_training(
    n_trials,
    data_dir,
    output_base_dir,
    constraint_points,
    constraint_points_swapped,
    Gamma_plus,
    Gamma_minus,
    sdp_obj,
    flow_config=None,
    train_config=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=50000,
):
    """
    Run ensemble training: train N independent flow triplets on N datasets.

    For each trial ``i`` the flavor / even / odd arrays ``set_i.npy`` are
    loaded, three fresh flows are created and trained jointly with
    ``train_joint_flows_with_B_constraint``, and the weights and history are
    written to disk.

    Parameters
    ----------
    n_trials : int
        Number of training trials (should match number of dataset sets)
    data_dir : str or Path
        Directory containing flavor/, even/, odd/ subdirectories with set_*.npy files
    output_base_dir : str or Path
        Base directory for output. Creates:
            output_base_dir/
                ensemble_config.json
                ensemble_flavor/trial_0.pth, trial_1.pth, ...
                ensemble_even/trial_0.pth, trial_1.pth, ...
                ensemble_odd/trial_0.pth, trial_1.pth, ...
                histories/trial_0.json, trial_1.json, ...
    constraint_points : torch.Tensor
        B-decay points for constraint evaluation
    constraint_points_swapped : torch.Tensor
        Swapped coordinates for constraint points
    Gamma_plus, Gamma_minus : float
        CP-even/odd decay widths
    sdp_obj : SquareDalitzPlot2
        SDP object for coordinate transformations
    flow_config : dict, optional
        Flow architecture config (num_flows, hidden_features, num_bins)
    train_config : dict, optional
        Extra keyword arguments for ``train_joint_flows_with_B_constraint``
        (num_epochs, warmup_epochs, lr_*, lam, ...). It may also set
        ``constraint_batch_size`` (default 50000). ``Gamma_plus``,
        ``Gamma_minus`` and ``device`` keys are overridden by the explicit
        arguments of this function instead of causing a duplicate-keyword error.
    device : str or torch.device
        Torch device
    batch_size : int
        DataLoader batch size for the joint training data (default 50000).

    Returns
    -------
    all_histories : list of dicts
        Training history for each trial
    """
    data_dir = Path(data_dir)
    output_base_dir = Path(output_base_dir)

    # Default configs
    if flow_config is None:
        flow_config = {'num_flows': 8, 'hidden_features': 64, 'num_bins': 16}
    if train_config is None:
        train_config = {
            'num_epochs': 100,
            'warmup_epochs': 10,
            'lr_odd': 1e-3,
            'lr_even': 1e-3,
            'lr_flavor': 1e-3,
            'lam': 10.0
        }

    # Keyword arguments for the trainer. Explicit function arguments win over
    # same-named keys in train_config (passing both via ** would raise
    # "got multiple values for keyword argument").
    trainer_kwargs = {
        'constraint_batch_size': 50000,
        **train_config,
        'Gamma_plus': Gamma_plus,
        'Gamma_minus': Gamma_minus,
        'device': device,
    }
    # Match "cuda", "cuda:0" and torch.device("cuda", ...) alike.
    pin_memory = str(device).startswith("cuda")

    # Create output directories
    flavor_out = output_base_dir / "ensemble_flavor"
    even_out = output_base_dir / "ensemble_even"
    odd_out = output_base_dir / "ensemble_odd"
    history_out = output_base_dir / "histories"
    for directory in (flavor_out, even_out, odd_out, history_out):
        directory.mkdir(parents=True, exist_ok=True)

    # Save configs. train_config may contain values json cannot encode
    # natively (e.g. a torch.device); those are deliberately stored as strings.
    config_save = {
        'flow_config': flow_config,
        'train_config': train_config,
        'n_trials': n_trials,
        'data_dir': str(data_dir),
        'timestamp': datetime.now().isoformat()
    }
    with open(output_base_dir / "ensemble_config.json", 'w') as f:
        json.dump(config_save, f, indent=2, default=str)

    all_histories = []

    print("=" * 70)
    print(f"ENSEMBLE TRAINING: {n_trials} trials")
    print(f"Data directory: {data_dir}")
    print(f"Output directory: {output_base_dir}")
    print(f"Flow config: {flow_config}")
    print(f"Train config: {train_config}")
    print("=" * 70)

    for trial in range(n_trials):
        print(f"\n{'='*70}")
        print(f"TRIAL {trial + 1}/{n_trials}")
        print(f"{'='*70}")

        # Load data for this trial
        print(f"\nLoading dataset set_{trial}...")
        data_flavor, data_even, data_odd = _load_trial_datasets(data_dir, trial)

        print(f"  Flavor: {len(data_flavor):,} events")
        print(f"  Even:   {len(data_even):,} events")
        print(f"  Odd:    {len(data_odd):,} events")

        # Precompute swapped coordinates for odd data
        print("Precomputing swapped coordinates for odd data...")
        data_odd_swapped = compute_swapped_sdp_coords(
            data_odd, sdp_obj, idx=(1,2,3), pair_swap=(1,3,2)
        )

        # Create dataset and loader
        joint_dataset = JointDalitzDataset(data_flavor, data_even, data_odd, data_odd_swapped)
        joint_loader = DataLoader(
            joint_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=pin_memory
        )

        # Create fresh flows
        print("Creating new flows...")
        flow_flavor = create_flow(**flow_config, device=device)
        flow_even = create_flow(**flow_config, device=device)
        flow_odd = create_flow(**flow_config, device=device)

        # Train
        print("\nStarting training...")
        (flow_flavor_t, flow_even_t, flow_odd_t), history = train_joint_flows_with_B_constraint(
            flow_flavor,
            flow_even,
            flow_odd,
            joint_loader,
            constraint_points,
            constraint_points_swapped,
            **trainer_kwargs
        )

        # Save models
        torch.save(flow_flavor_t.state_dict(), flavor_out / f"trial_{trial}.pth")
        torch.save(flow_even_t.state_dict(), even_out / f"trial_{trial}.pth")
        torch.save(flow_odd_t.state_dict(), odd_out / f"trial_{trial}.pth")

        # Save history
        with open(history_out / f"trial_{trial}.json", 'w') as f:
            json.dump(history, f, indent=2)

        all_histories.append(history)

        # Print summary for this trial
        print(f"\nTrial {trial + 1} complete:")
        print(f"  Final NLL - Flavor: {history['nll_flavor'][-1]:.4f}")
        print(f"  Final NLL - Even:   {history['nll_even'][-1]:.4f}")
        print(f"  Final NLL - Odd:    {history['nll_odd'][-1]:.4f}")
        print(f"  Final violation:    {history['violation_frac'][-1]:.4f}")

        # Release this trial's objects before the next trial allocates its own.
        del flow_flavor, flow_even, flow_odd
        del flow_flavor_t, flow_even_t, flow_odd_t
        del joint_dataset, joint_loader
        del data_flavor, data_even, data_odd, data_odd_swapped
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n" + "=" * 70)
    print("ENSEMBLE TRAINING COMPLETE")
    print("=" * 70)
    print(f"\nModels saved to:")
    print(f"  Flavor: {flavor_out}/")
    print(f"  Even:   {even_out}/")
    print(f"  Odd:    {odd_out}/")
    print(f"  Histories: {history_out}/")

    _print_ensemble_summary(all_histories)

    return all_histories

In [None]:
# =============================================================================
# 13c. Generate Datasets — run ONCE; each ensemble trial reads one set from disk
# =============================================================================
N_SETS = 10                     # independent datasets (one per ensemble trial)
N_EVENTS = 2_000_000            # events per dataset and per flow type (2M)
DATA_DIR = "ensemble_data_2e6"  # root directory for flavor/, even/, odd/

generate_all_datasets(N_SETS, N_EVENTS, sdp_obj, DATA_DIR, base_seed=42)

In [None]:
# =============================================================================
# 13d. Run Ensemble Training
# =============================================================================
# Prerequisites:
#   1. datasets written by cell 13c above
#   2. B-decay constraint points loaded (section 7b)
#   3. Gamma_plus / Gamma_minus computed (section 5)

N_TRIALS = 10
DATA_DIR = "ensemble_data_2e6"
OUTPUT_DIR = "joint_ensemble_2e6"

# Architecture shared by every flow in every trial
ENSEMBLE_FLOW_CONFIG = dict(num_flows=8, hidden_features=64, num_bins=16)

# Optimisation hyper-parameters forwarded to the joint trainer
ENSEMBLE_TRAIN_CONFIG = dict(
    num_epochs=100,
    warmup_epochs=10,
    lr_odd=1e-3,
    lr_even=1e-3,
    lr_flavor=1e-3,
    lam=10.0,
)

all_histories = run_ensemble_training(
    n_trials=N_TRIALS,
    data_dir=DATA_DIR,
    output_base_dir=OUTPUT_DIR,
    sdp_obj=sdp_obj,
    constraint_points=constraint_points,
    constraint_points_swapped=constraint_points_swapped,
    Gamma_plus=gamma_p,
    Gamma_minus=gamma_m,
    flow_config=ENSEMBLE_FLOW_CONFIG,
    train_config=ENSEMBLE_TRAIN_CONFIG,
    device=device,
)

In [None]:
# =============================================================================
# 13e. Plot Ensemble Results
# =============================================================================

def plot_ensemble_histories(all_histories, output_dir=None):
    """
    Plot training curves for all ensemble trials and print summary statistics.

    Draws a 2x3 grid: per-trial NLL curves (flavor, even, odd), violation
    fraction, constraint penalty (log scale) and a box plot of final NLLs.

    Parameters
    ----------
    all_histories : list of dict
        One training history per trial, each with the keys 'nll_flavor',
        'nll_even', 'nll_odd', 'violation_frac' and 'penalty_loss'.
    output_dir : str or Path, optional
        If given, the figure is saved there as 'ensemble_training_summary.pdf'.
    """
    n_trials = len(all_histories)
    if n_trials == 0:
        # Nothing to draw; min()/max() on the empty final-value lists would raise.
        print("No ensemble histories to plot.")
        return

    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    colors = plt.cm.viridis(np.linspace(0, 1, n_trials))

    # Per-trial curve panels: (axis position, history key, y-label, title, log-y?)
    curve_panels = (
        ((0, 0), 'nll_flavor', 'NLL', 'NLL Flavor (all trials)', False),
        ((0, 1), 'nll_even', 'NLL', 'NLL Even (all trials)', False),
        ((0, 2), 'nll_odd', 'NLL', 'NLL Odd (all trials)', False),
        ((1, 0), 'violation_frac', 'Fraction', 'Violation Fraction (all trials)', False),
        ((1, 1), 'penalty_loss', 'Penalty', 'Constraint Penalty (all trials)', True),
    )
    for pos, key, ylabel, title, log_y in curve_panels:
        ax = axes[pos]
        for color, h in zip(colors, all_histories):
            ax.plot(h[key], color=color, alpha=0.7, linewidth=1)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if log_y:
            ax.set_yscale('log')
        ax.grid(True, alpha=0.3)

    # Final NLL distribution
    ax = axes[1, 2]
    final_flavor = [h['nll_flavor'][-1] for h in all_histories]
    final_even = [h['nll_even'][-1] for h in all_histories]
    final_odd = [h['nll_odd'][-1] for h in all_histories]

    positions = [0, 1, 2]
    ax.boxplot([final_flavor, final_even, final_odd], positions=positions)
    ax.set_xticks(positions)
    ax.set_xticklabels(['Flavor', 'Even', 'Odd'])
    ax.set_ylabel('Final NLL')
    ax.set_title('Final NLL Distribution')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if output_dir:
        plt.savefig(Path(output_dir) / 'ensemble_training_summary.pdf', dpi=150, bbox_inches='tight')

    plt.show()

    # Print statistics
    print("\n" + "=" * 50)
    print("ENSEMBLE STATISTICS")
    print("=" * 50)
    print("\nFinal NLL (mean ± std):")
    print(f"  Flavor: {np.mean(final_flavor):.4f} ± {np.std(final_flavor):.4f}")
    print(f"  Even:   {np.mean(final_even):.4f} ± {np.std(final_even):.4f}")
    print(f"  Odd:    {np.mean(final_odd):.4f} ± {np.std(final_odd):.4f}")

    final_viol = [h['violation_frac'][-1] for h in all_histories]
    print("\nFinal violation fraction:")
    print(f"  Mean: {np.mean(final_viol):.4f} ± {np.std(final_viol):.4f}")
    print(f"  Range: [{min(final_viol):.4f}, {max(final_viol):.4f}]")

# Plot results
plot_ensemble_histories(all_histories, output_dir=OUTPUT_DIR)

---
## 14. Loading Ensemble Models for Inference

After training, load the ensemble models for γ extraction.

In [None]:
# =============================================================================
# 14. Load Ensemble Models
# =============================================================================

def load_ensemble_flows(ensemble_dir, flow_config, n_trials=None, device='cpu'):
    """
    Load all trained flows from an ensemble directory.

    Checkpoints are expected at ``<ensemble_dir>/ensemble_<kind>/trial_<i>.pth``
    for kind in (flavor, even, odd) and i = 0 .. n_trials-1. Each file must hold
    a state dict compatible with ``create_flow(**flow_config, device=device)``.

    Parameters
    ----------
    ensemble_dir : str or Path
        Directory containing ensemble_flavor/, ensemble_even/, ensemble_odd/
    flow_config : dict
        Flow architecture config used during training
    n_trials : int, optional
        Number of trials to load. If None, it is auto-detected as the number
        of files in ``ensemble_flavor/`` named ``trial_<integer>.pth``.
    device : str
        Torch device

    Returns
    -------
    flows_flavor : list of Flow
    flows_even : list of Flow
    flows_odd : list of Flow
        One flow per trial, in trial order, each switched to eval mode.

    Raises
    ------
    FileNotFoundError
        If ``n_trials`` is auto-detected and no ``trial_<integer>.pth`` file
        exists in ``ensemble_flavor/``, or if an expected checkpoint is missing.
    """
    ensemble_dir = Path(ensemble_dir)
    kinds = ("flavor", "even", "odd")

    # Auto-detect number of trials
    if n_trials is None:
        flavor_dir = ensemble_dir / "ensemble_flavor"
        # Only count 'trial_<integer>.pth': the bare glob would also match
        # files like 'trial_best.pth' and inflate the trial count.
        n_trials = sum(
            1
            for p in flavor_dir.glob("trial_*.pth")
            if p.stem[len("trial_"):].isdecimal()
        )
        if n_trials == 0:
            # Returning empty lists here would only defer the failure to the
            # ensemble-averaging code; fail early with an actionable message.
            raise FileNotFoundError(
                f"No trial_<i>.pth checkpoints found in {flavor_dir}; "
                "check ensemble_dir or pass n_trials explicitly"
            )

    print(f"Loading {n_trials} ensemble models from {ensemble_dir}")

    loaded = {kind: [] for kind in kinds}
    for i in range(n_trials):
        for kind in kinds:
            # Fresh flow with the training architecture, then restore weights
            flow = create_flow(**flow_config, device=device)
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(
                ensemble_dir / f"ensemble_{kind}" / f"trial_{i}.pth",
                map_location=device,
            )
            flow.load_state_dict(state_dict)
            flow.eval()
            loaded[kind].append(flow)

        print(f"  Loaded trial {i}")

    print(f"Done! Loaded {n_trials} × 3 = {n_trials * 3} models")

    return loaded["flavor"], loaded["even"], loaded["odd"]


# Example usage (uncomment to run):
# flows_flavor, flows_even, flows_odd = load_ensemble_flows(
#     ensemble_dir="joint_ensemble_2e6",
#     flow_config=ENSEMBLE_FLOW_CONFIG,
#     device=device
# )