# Optimal Transport Event Shapes: Thrust and Isotropy

This notebook explores the geometric definition of collider event observables using **Optimal Transport (OT)**.

Instead of using traditional formulas (like the transverse thrust $T = \max_{\hat{n}} \sum_i p_{T,i} |\hat{n} \cdot \hat{p}_{T,i}| / \sum_i p_{T,i}$), we define these shapes as the **Earth Mover's Distance (EMD)** between an event and a specific **Reference Manifold**.

We will explore two key observables:
1.  **Thrust:** The distance from the event to the manifold of "back-to-back" dijets.
2.  **Event Isotropy:** The distance from the event to the manifold of "isotropic" (uniform) radiation.

This perspective unifies event shapes under a single geometric framework: **How much work is required to rearrange the energy flow of an event into a specific ideal shape?**
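Written out, the EMD between an event $E = \{(\phi_i, w_i)\}$ and a reference $E' = \{(\phi'_j, w'_j)\}$ of equal total weight is the cost of the cheapest transport plan $G$ that turns one into the other. An event shape $\mathcal{O}$ is then the smallest such distance to any member of the reference manifold $\mathcal{M}$ (the symbols $G$, $\mathcal{O}$ and $\mathcal{M}$ are introduced here for illustration):

$$ \text{EMD}(E, E') = \min_{G_{ij} \ge 0} \sum_{i,j} G_{ij}\, d(\phi_i, \phi'_j) \quad \text{subject to} \quad \sum_j G_{ij} = w_i, \quad \sum_i G_{ij} = w'_j $$

$$ \mathcal{O}(E) = \min_{E' \in \mathcal{M}} \text{EMD}(E, E') $$

For Thrust, $\mathcal{M}$ contains back-to-back dijets at every orientation; for Event Isotropy, it contains the uniform ring at every rotation. Here $d$ is the ground metric on the circle, chosen below.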

In [None]:
# ---------------------------------------------------------
# GOOGLE COLAB SETUP
# ---------------------------------------------------------
import sys
import os
import subprocess
import re

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Detected Google Colab environment.")

    # ---------------------------------------------------------
    # 1. Smart Installation Check (ensure NumPy 1.26.x + SciPy 1.11.x)
    # ---------------------------------------------------------
    def pip_show(pkg):
        try:
            res = subprocess.run([sys.executable, "-m", "pip", "show", pkg], capture_output=True, text=True)
            m = re.search(r"Version:\s*([\w\.\-]+)", res.stdout)
            return m.group(1) if m else "(not installed)"
        except Exception:
            return "(error)"

    current_numpy = pip_show("numpy")
    current_scipy = pip_show("scipy")
    print(f"Detected (pip): numpy={current_numpy}, scipy={current_scipy}")

    # Target versions known-good with POT/Colab
    target_numpy = "1.26.4"
    target_scipy = "1.11.4"

    needs_install = (current_numpy != target_numpy) or (current_scipy != target_scipy)

    # Also ensure the other required packages are present
    for pkg in ("energyflow", "pot", "h5py"):
        if pip_show(pkg) == "(not installed)":
            needs_install = True

    if needs_install:
        print("Installing compatible stack (this may take a minute)...")
        !pip install -q -U --force-reinstall "numpy==1.26.4" "scipy==1.11.4" energyflow pot watermark h5py
        # Re-check
        current_numpy = pip_show("numpy")
        current_scipy = pip_show("scipy")
        print(f"Post-install (pip): numpy={current_numpy}, scipy={current_scipy}")
        print("\n" + "="*80)
        print("INSTALLATION COMPLETE.")
        print("IMPORTANT: Restart the Colab runtime now:")
        print("Runtime > Restart session, then re-run this setup cell.")
        print("="*80 + "\n")
    else:
        print("Environment appears correct. Skipping installation.")

    # ---------------------------------------------------------
    # 2. Download Dataset
    # ---------------------------------------------------------
    folder_id = '1CJe9xkIk1QmTXJ8g__zagAvn3uoxue5a'
    output_folder = 'isotropy'

    if not os.path.exists(output_folder):
        print(f"Downloading dataset to {output_folder}/...")
        !pip install -q -U --no-cache-dir gdown --pre
        import gdown
        gdown.download_folder(id=folder_id, output=output_folder, quiet=False)
    else:
        print(f"Folder '{output_folder}' already exists. Skipping download.")
else:
    print("Not running in Google Colab. Assuming local environment.")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp

# Guard against ABI mismatches in Colab
print(f"NumPy version: {np.__version__}")
print(f"SciPy version: {sp.__version__}")

# In Colab, require the pinned stack installed by the setup cell; local environments are not checked
if globals().get("IN_COLAB", False) and (not np.__version__.startswith("1.26.") or not sp.__version__.startswith("1.11.")):
    raise RuntimeError(
        "Incompatible NumPy/SciPy versions detected.\n"
        "Please run the 'GOOGLE COLAB SETUP' cell, then Runtime > Restart session,\n"
        "then re-run the setup cell before proceeding."
    )

import ot  # Python Optimal Transport
import energyflow as ef
import h5py  # For data loading

# Plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

print(f"POT version: {ot.__version__}")
print(f"EnergyFlow version: {ef.__version__}")
print(f"h5py version: {h5py.__version__}")

## Generating Reference Events

To understand these observables, we first create reference events with idealized topologies:
1.  **Dijet:** Two particles with equal $p_T$, back-to-back ($\Delta\phi = \pi$).
2.  **Trijet ("Mercedes"):** Three particles with equal $p_T$ separated by $120^\circ$.
3.  **Isotropic:** Many equal-$p_T$ particles on a uniform grid in $\phi$.

We work in the **Transverse Plane** (using only $\phi$) for simplicity, as is common for transverse thrust and ring isotropy.
An event is represented by weights $w_i = p_{T,i} / \sum_j p_{T,j}$ (normalized $p_T$, so $\sum_i w_i = 1$) and coordinates $x_i$, which in the transverse plane are just the azimuthal angles $\phi_i$.
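As a minimal sketch of this representation, the hypothetical helper `make_event` below converts raw $(p_T, \phi)$ lists into the `(phis, weights)` arrays used throughout. The optional recoil step mirrors what the data loader later in the notebook does; the `pt_min` cut and the $[0, 2\pi)$ wrapping convention are illustrative assumptions.

```python
import numpy as np

def make_event(pts, phis, pt_min=0.0, add_recoil=False):
    """Hypothetical helper: raw (pT, phi) lists -> (phis, weights) as used in this notebook."""
    pts = np.asarray(pts, dtype=float)
    phis = np.asarray(phis, dtype=float)
    keep = pts > pt_min                      # illustrative pT cut
    pts, phis = pts[keep], phis[keep]
    if add_recoil:
        # Append a pseudo-particle that cancels the net transverse momentum
        px, py = np.sum(pts * np.cos(phis)), np.sum(pts * np.sin(phis))
        pts = np.append(pts, np.hypot(px, py))
        phis = np.append(phis, np.arctan2(-py, -px))
    phis = np.mod(phis, 2 * np.pi)           # convention assumed here: phi in [0, 2pi)
    weights = pts / np.sum(pts)              # pT fractions, sum to 1
    return phis, weights

ev_phis, ev_w = make_event([30.0, 10.0], [0.0, np.pi / 2], add_recoil=True)
print(np.round(ev_phis, 3), np.round(ev_w, 3))
```

With `add_recoil=True` the weighted vector sum $\sum_i w_i(\cos\phi_i, \sin\phi_i)$ vanishes, which is the balanced configuration the reference events above already satisfy.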

In [None]:
def get_dijet_reference(alpha=0.0):
    """Returns the (phis, weights) of a dijet aligned along alpha."""
    phis = np.array([alpha, alpha + np.pi])
    # Wrap to [0, 2pi)
    phis = np.mod(phis, 2*np.pi)
    weights = np.array([0.5, 0.5])
    return phis, weights

def create_mercedes_event():
    """Creates a symmetric 3-jet event (Mercedes star)."""
    # Three particles at 0, 2pi/3, 4pi/3
    phis = np.array([0.0, 2*np.pi/3, 4*np.pi/3])
    weights = np.array([1/3, 1/3, 1/3])
    return phis, weights

def get_isotropic_reference(n_particles=128):
    """Returns the (phis, weights) of the isotropic reference."""
    # Uniform grid in phi
    phis = np.linspace(0, 2*np.pi, n_particles, endpoint=False)
    weights = np.ones(n_particles) / n_particles
    return phis, weights

# Generate and visualize
ev_dijet = get_dijet_reference()
ev_mercedes = create_mercedes_event()
ev_iso = get_isotropic_reference(n_particles=100)

def plot_event_phi(ax, phis, weights, title):
    """Draws an event on a polar axis: bar height = pT fraction at each phi."""
    ax.bar(phis, weights, width=0.1, bottom=0.0, alpha=0.7)
    ax.set_title(title, y=1.1)
    ax.set_yticks([])

fig = plt.figure(figsize=(15, 4))
panels = [(ev_dijet, "Dijet"), (ev_mercedes, "Mercedes"), (ev_iso, "Isotropic")]
for k, ((ev_phis, ev_w), title) in enumerate(panels):
    ax = fig.add_subplot(1, 3, k + 1, projection='polar')
    plot_event_phi(ax, ev_phis, ev_w, title)
plt.tight_layout()
plt.show()

## Preprocessing: Alignment

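To compute OT-Thrust efficiently, we rotate the event so that its thrust axis lies at $\phi=0$.
We first find the **Thrust Axis** $\hat{n}_T$, which maximizes the projected transverse momentum:
$$ \hat{n}_T = \arg \max_{\hat{n}} \sum_i p_{T,i} |\hat{n} \cdot \hat{p}_{T,i}| $$
We then rotate the event so that $\hat{n}_T$ points along $\phi=0$.
This fixes the orientation of the reference, so we only need the distance to the *fixed* dijet reference at $\alpha=0$. This is a choice rather than an exact minimization: the thrust axis need not be the dijet orientation with the smallest transport cost, so the resulting distance is an upper bound on the EMD to the full manifold of rotated dijets.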
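The coarse grid plus Brent refinement in `get_thrust_axis` is a numerical search. In the transverse plane the thrust can also be computed exactly, which makes a useful cross-check: $\max_{\hat n}\sum_i |\hat n\cdot\vec p_{T,i}| = \max_{s_i = \pm 1} |\sum_i s_i \vec p_{T,i}|$, and the sign pattern $s_i = \mathrm{sign}\cos(\phi_i - \alpha)$ only changes when $\alpha$ crosses $\phi_i \pm \pi/2$, so it is enough to test one angle inside each arc between those crossings. The name `transverse_thrust_exact` is hypothetical, and this $O(n^2)$ sketch is meant for small events.

```python
import numpy as np

def transverse_thrust_exact(phis, pts):
    """Exact transverse thrust and axis (mod pi) by enumerating hemisphere splits.

    T = max_n sum_i |n . p_i| / sum_i p_i. The sign pattern s_i = sign(cos(phi_i - a))
    only changes at a = phi_i +/- pi/2, so one probe angle per arc covers all splits.
    """
    phis = np.asarray(phis, dtype=float)
    pts = np.asarray(pts, dtype=float)
    px, py = pts * np.cos(phis), pts * np.sin(phis)
    breaks = np.sort(np.mod(np.concatenate([phis + np.pi / 2, phis - np.pi / 2]), 2 * np.pi))
    gaps = np.diff(np.append(breaks, breaks[0] + 2 * np.pi))
    probes = breaks + 0.5 * gaps                 # one angle strictly inside each arc
    best_T, best_axis = -1.0, 0.0
    for a in probes:
        s = np.sign(np.cos(phis - a))            # hemisphere assignment on this arc
        vx, vy = np.sum(s * px), np.sum(s * py)  # signed vector sum for this split
        T = np.hypot(vx, vy) / np.sum(pts)
        if T > best_T:
            best_T, best_axis = T, np.arctan2(vy, vx)
    return best_T, np.mod(best_axis, np.pi)

T, axis = transverse_thrust_exact([0.0, 2 * np.pi / 3, 4 * np.pi / 3], [1.0, 1.0, 1.0])
print(f"Mercedes: T = {T:.4f}, axis = {axis:.4f} rad")
```

For the Mercedes event this gives $T = 2/3$, and for a perfect dijet $T = 1$; the axis returned by `get_thrust_axis` can be compared against `axis` modulo $\pi$.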

In [None]:
from scipy.optimize import minimize_scalar

def get_thrust_axis(phis, weights):
    """Finds the angle alpha that maximizes sum(w_i * |cos(phi_i - alpha)|)."""
    def objective(alpha):
        # We want to MAXIMIZE sum, so we MINIMIZE negative sum
        val = np.sum(weights * np.abs(np.cos(phis - alpha)))
        return -val
    
    # Brute force coarse search to avoid local minima
    test_alphas = np.linspace(0, np.pi, 20)
    vals = [objective(a) for a in test_alphas]
    best_alpha_init = test_alphas[np.argmin(vals)]
    
    # Refine
    res = minimize_scalar(objective, bracket=(best_alpha_init - 0.2, best_alpha_init + 0.2))
    return res.x

def align_event(phis, weights):
    """Rotates the event so the thrust axis is at phi=0."""
    alpha = get_thrust_axis(phis, weights)
    phis_aligned = phis - alpha
    # Wrap to [-pi, pi] for plotting/consistency
    phis_aligned = (phis_aligned + np.pi) % (2 * np.pi) - np.pi
    return phis_aligned, weights

print("Alignment function defined.")

## Event Shape 1: Thrust

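We define **OT-Thrust** $\tau_{OT}$ as the EMD between the aligned event and the fixed dijet reference $\mathcal{D}_0$ ($\phi \in \{0, \pi\}$, weight $1/2$ each). Unlike the traditional thrust $T$, which equals 1 for a perfect dijet, $\tau_{OT}$ vanishes for a perfect dijet, so it behaves qualitatively like $1 - T$.
$$ \tau_{OT} = \text{EMD}(E_{\text{aligned}}, \mathcal{D}_0) $$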

**Metric:**
Following the Event Isotropy construction, we use a distance based on the cosine of the angle difference (related to the chordal distance):
$$ d(\phi_1, \phi_2) = (1 - \cos(\phi_1 - \phi_2))^\beta $$
Common choices are $\beta=1$ and $\beta=2$. The functions below default to $\beta=2$ for thrust and $\beta=1$ for isotropy, while the batch analysis further down uses $\beta=2$ for both.
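The link to the chordal distance can be made exact: on the unit circle, two points separated by $\Delta\phi$ are joined by a chord of length $\ell_{\text{chord}} = 2|\sin(\Delta\phi/2)|$, so

$$ 1-\cos\Delta\phi = 2\sin^2\frac{\Delta\phi}{2} = \frac{1}{2}\,\ell_{\text{chord}}^2 \approx \frac{\Delta\phi^2}{2} \quad (\Delta\phi \to 0) $$

For small moves $d \approx (\Delta\phi^2/2)^\beta$, so $\beta=1$ penalizes displacements roughly like a squared angle and $\beta=2$ like its fourth power, while the largest possible cost, at $\Delta\phi=\pi$, is $2^\beta$.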

*   **Low Value ($\approx 0$):** Event is very close to a dijet (pencil-like).
*   **High Value:** Event is far from a dijet (e.g., three-jet or uniform in $\phi$).
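With only two target points, the transport problem behind OT-Thrust is simple enough to solve by hand. This gives an independent check on `compute_ot_thrust` and shows how the result depends on the dijet orientation $\alpha$. The sketch below (the name `emd_to_dijet` is hypothetical, and it uses NumPy only rather than POT) starts every particle at the pole at $\alpha+\pi$ and then hands the $1/2$ unit of capacity at the pole at $\alpha$ to the particles that save the most cost by moving there. For two targets this greedy fill solves the linear program exactly.

```python
import numpy as np

def emd_to_dijet(phis, weights, alpha=0.0, beta=2.0):
    """Exact EMD from an event to a dijet at (alpha, alpha + pi), weight 1/2 per pole.

    Uses the ring metric d = (1 - cos(dphi))**beta. Weights must sum to 1.
    """
    phis = np.asarray(phis, dtype=float)
    w = np.asarray(weights, dtype=float)
    c0 = (1.0 - np.cos(phis - alpha)) ** beta           # cost to the pole at alpha
    c1 = (1.0 - np.cos(phis - alpha - np.pi)) ** beta   # cost to the pole at alpha + pi
    saving = c1 - c0                                     # gain from using the pole at alpha
    x = np.zeros_like(w)                                 # mass sent to the pole at alpha
    remaining = 0.5
    for i in np.argsort(-saving):                        # largest saving first
        x[i] = min(w[i], remaining)
        remaining -= x[i]
        if remaining <= 0.0:
            break
    return float(np.sum(c0 * x + c1 * (w - x)))

merc_phis = np.array([0.0, 2 * np.pi / 3, 4 * np.pi / 3])
merc_w = np.full(3, 1 / 3)
for a in (0.0, np.pi / 2):
    print(f"alpha = {a:.3f}: EMD to dijet = {emd_to_dijet(merc_phis, merc_w, alpha=a):.4f}")
```

For the Mercedes event with $\beta=2$ this gives $0.5$ at the thrust axis ($\alpha=0$) but about $0.345$ for a dijet axis perpendicular to one jet ($\alpha=\pi/2$), so aligning to the thrust axis does not minimize the EMD over dijet orientations.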

In [None]:
def ring_distance(phi1, phi2, beta=2.0):
    """
    Computes the distance between angles on a circle.
    Metric: d = (1 - cos(phi1 - phi2))**beta
    """
    # cos(a-b) handles the periodicity naturally
    d = (1.0 - np.cos(phi1 - phi2))**beta
    return d

def compute_ot_thrust(phis, weights, beta=2.0):
    """Computes OT-Thrust (distance to dijet)."""
    # 1. Align
    phis_aligned, w_aligned = align_event(phis, weights)
    
    # 2. Get Reference
    phis_ref, w_ref = get_dijet_reference(alpha=0.0)
    
    # 3. Compute Cost Matrix
    # M[i, j] = dist(event[i], ref[j])
    M = ring_distance(phis_aligned[:, None], phis_ref[None, :], beta=beta)
    
    # 4. Compute EMD
    val = ot.emd2(w_aligned, w_ref, M)
    return val

# Test
t_dijet = compute_ot_thrust(*ev_dijet)
t_mercedes = compute_ot_thrust(*ev_mercedes)
t_iso = compute_ot_thrust(*ev_iso)

print(f"OT-Thrust (Dijet):     {t_dijet:.4f} (Expected ~0)")
print(f"OT-Thrust (Mercedes):  {t_mercedes:.4f} (Expected ~0.5 for beta=2)")
print(f"OT-Thrust (Isotropic): {t_iso:.4f}")

## Event Shape 2: Event Isotropy

**Event Isotropy** $\mathcal{I}$ is the distance to the Isotropic Reference (Uniform Ring).
$$ \mathcal{I} = \text{EMD}(E, \mathcal{U}) $$

**Rotational Invariance:**
For a continuous uniform ring, the distance is invariant under rotations of the reference. Our $N$-point approximation is only invariant under rotations by multiples of $2\pi/N$, so we **minimize the EMD over the rotation angle** (shift) of the reference grid within one period $[0, 2\pi/N)$. This removes the dependence on where the grid points happen to sit.

*   **Low Value ($\approx 0$):** Event is isotropic (uniform in $\phi$).
*   **High Value:** Event is far from isotropic (e.g., dijet).
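As a reference point for the high end of the scale, the isotropy of a perfect dijet can be worked out in the continuum limit. Sending each point of the uniform ring to its nearest pole costs $\min(1-\cos\phi,\,1+\cos\phi) = 1-|\cos\phi|$ and, by symmetry, already delivers weight $1/2$ to each pole, so this assignment is optimal:

$$ \mathcal{I}_{\text{dijet}} = \frac{1}{2\pi}\int_0^{2\pi} \left(1-|\cos\phi|\right)^\beta d\phi = \begin{cases} 1-\dfrac{2}{\pi} \approx 0.363, & \beta = 1,\\[6pt] \dfrac{3}{2}-\dfrac{4}{\pi} \approx 0.227, & \beta = 2, \end{cases} $$

using $\langle|\cos\phi|\rangle = 2/\pi$ and $\langle\cos^2\phi\rangle = 1/2$. The discrete $N$-point reference gives values close to these, with small $N$-dependent corrections; the dijet normalization constant used in the batch analysis ($\beta=2$, 256 points) should therefore come out near $0.227$.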

In [None]:
from scipy.optimize import minimize_scalar

def compute_isotropy(phis, weights, n_ref=128, beta=1.0):
    """
    Computes Event Isotropy (distance to uniform ring).
    Minimizes EMD over rotations of the reference ring.
    """
    # 1. Get Reference (centered at 0)
    phis_ref_0, w_ref = get_isotropic_reference(n_particles=n_ref)
    
    # Define objective function for rotation
    def objective(shift):
        # Rotate reference by shift
        phis_ref_shifted = phis_ref_0 + shift
        
        # Compute Cost Matrix
        M = ring_distance(phis[:, None], phis_ref_shifted[None, :], beta=beta)
        
        # Compute EMD
        return ot.emd2(weights, w_ref, M)

    # Minimize over shift.
    # The N-point reference is invariant under rotations by 2pi/N,
    # so one period [0, 2pi/N] covers every distinct orientation.
    bounds = (0, 2*np.pi / n_ref)
    
    # Use bounded minimization
    res = minimize_scalar(objective, bounds=bounds, method='bounded')
    
    return res.fun

# Test
i_dijet = compute_isotropy(*ev_dijet)
i_mercedes = compute_isotropy(*ev_mercedes)
i_iso = compute_isotropy(*ev_iso)

print(f"Isotropy (Dijet):     {i_dijet:.4f} (Expected ~1 - 2/pi = 0.36 for beta=1)")
print(f"Isotropy (Mercedes):  {i_mercedes:.4f}")
print(f"Isotropy (Isotropic): {i_iso:.4f} (Expected ~0)")

## Visualizing the Geometry

To understand *why* an event has a certain shape value, we visualize the **Optimal Transport Plan**.
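Each arc runs along the ring from an event particle (source, red) to a reference point (target, blue) it is matched with, and its width is proportional to the transported weight.
The shape value is the total cost $\sum_{ij} G_{ij}\, d(\phi_i, \phi_j)$: weight moved times the ground-metric cost $(1-\cos\Delta\phi)^\beta$, not the drawn arc length.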
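The transport plan $G$ and its total cost can also be obtained from a generic linear-program solver, which makes the structure of the problem explicit: minimize $\sum_{ij} G_{ij} M_{ij}$ subject to the row sums of $G$ matching the event weights and the column sums matching the reference weights. The sketch below uses `scipy.optimize.linprog` with the HiGHS backend instead of POT purely for illustration; `emd_plan_lp` is a hypothetical name, and POT's `ot.emd` remains the practical choice for large events.

```python
import numpy as np
from scipy.optimize import linprog

def emd_plan_lp(phis_a, w_a, phis_b, w_b, beta=2.0):
    """Optimal transport plan and cost on the ring, solved as a dense linear program."""
    phis_a, w_a = np.asarray(phis_a, float), np.asarray(w_a, float)
    phis_b, w_b = np.asarray(phis_b, float), np.asarray(w_b, float)
    n, m = len(w_a), len(w_b)
    # Cost matrix M[i, j] = (1 - cos(phi_i - phi_j))**beta, same metric as ring_distance
    M = (1.0 - np.cos(np.subtract.outer(phis_a, phis_b))) ** beta
    # Plan G is flattened row-major: entry (i, j) -> index i*m + j
    A_rows = np.kron(np.eye(n), np.ones((1, m)))   # sum_j G[i, j] = w_a[i]
    A_cols = np.kron(np.ones((1, n)), np.eye(m))   # sum_i G[i, j] = w_b[j]
    res = linprog(M.ravel(),
                  A_eq=np.vstack([A_rows, A_cols]),
                  b_eq=np.concatenate([w_a, w_b]),
                  bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    G = res.x.reshape(n, m)
    return G, float(np.sum(G * M))   # total cost = sum_ij G_ij * M_ij

# Mercedes event -> dijet reference at phi = 0, pi
G, cost = emd_plan_lp([0.0, 2*np.pi/3, 4*np.pi/3], [1/3]*3, [0.0, np.pi], [0.5, 0.5])
print(np.round(G, 4))
print(f"total cost = {cost:.4f}")
```

Here the jet at $\phi=0$ goes entirely to the pole at $\phi=0$, and the two side jets supply the remaining $1/6$ there plus the full $1/2$ at $\phi=\pi$; how that split is shared between the side jets is not unique.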

In [None]:
def plot_transport_plan(phis_ev, w_ev, phis_ref, w_ref, title="Transport Plan", beta=1.0):
    """Visualizes the OT plan on a circle using arcs at fixed radius."""
    # Compute Plan using the correct metric
    M = ring_distance(phis_ev[:, None], phis_ref[None, :], beta=beta)
    G = ot.emd(w_ev, w_ref, M)

    plt.figure(figsize=(6, 6))
    ax = plt.subplot(111, projection='polar')

    # Fixed radius for 1D OT on circle
    R = 1.0

    # Plot Source (Event) - Red
    # Both sets sit on the same ring; different colors and alpha keep overlaps visible
    ax.scatter(phis_ev, np.ones_like(phis_ev)*R, s=w_ev*1000, c='r', label='Event', alpha=0.7, zorder=3)

    # Plot Target (Reference) - Blue
    ax.scatter(phis_ref, np.ones_like(phis_ref)*R, s=w_ref*1000, c='b', label='Reference', alpha=0.4, zorder=2)

    # Plot Lines (Arcs along the ring)
    threshold = G.max() * 0.01
    for i in range(len(phis_ev)):
        for j in range(len(phis_ref)):
            if G[i, j] > threshold:
                # Determine shortest angular path
                phi_start = phis_ev[i]
                phi_end = phis_ref[j]
                diff = phi_end - phi_start

                # Wrap diff to [-pi, pi] to find shortest path
                diff = (diff + np.pi) % (2*np.pi) - np.pi

                # Interpolate points for the arc
                t = np.linspace(0, 1, 50)
                p_interp = phi_start + t * diff
                r_interp = np.ones_like(p_interp) * R

                # Line width scales with the transported weight G[i, j]
                ax.plot(p_interp, r_interp, 'k-', alpha=0.3, lw=G[i, j]*50)

    ax.set_ylim(0, 1.1)
    ax.set_yticks([])
    ax.set_title(title)
    plt.legend(loc='lower right', bbox_to_anchor=(1.3, 0))
    plt.show()

# Visualize Thrust (Event -> Dijet)
phis_aligned, w_aligned = align_event(ev_mercedes[0], ev_mercedes[1])
phis_ref_dijet, w_ref_dijet = get_dijet_reference()
# beta=2.0 matches the default used by compute_ot_thrust
plot_transport_plan(phis_aligned, w_aligned, phis_ref_dijet, w_ref_dijet, title="OT-Thrust Plan (Mercedes -> Dijet)", beta=2.0)

# Visualize Isotropy (Event -> Ring)
def get_optimal_shift(phis, weights, n_ref=32, beta=1.0):
    """Returns the rotation of the n_ref-point ring that minimizes the EMD to the event."""
    phis_ref_0, w_ref = get_isotropic_reference(n_particles=n_ref)
    def objective(shift):
        phis_ref_shifted = phis_ref_0 + shift
        M = ring_distance(phis[:, None], phis_ref_shifted[None, :], beta=beta)
        return ot.emd2(weights, w_ref, M)
    res = minimize_scalar(objective, bounds=(0, 2*np.pi/n_ref), method='bounded')
    return res.x

shift_opt = get_optimal_shift(ev_dijet[0], ev_dijet[1], n_ref=32)
phis_ref_iso, w_ref_iso = get_isotropic_reference(32)
phis_ref_iso_shifted = phis_ref_iso + shift_opt

plot_transport_plan(ev_dijet[0], ev_dijet[1], phis_ref_iso_shifted, w_ref_iso, title="Isotropy Plan (Dijet -> Ring)")

## Analysis: Correlations

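We now explore the correlation between these two observables.
We load a large sample of simulated (Pythia) background events from `background_events.h5`, compute both observables for each, and plot them in the **(Thrust, Isotropy)** plane. The loader keeps particles with $p_T > 0.1$ GeV and adds a recoil pseudo-particle that balances the event's net transverse momentum. Each observable is normalized by a reference value: isotropy by that of a perfect dijet, thrust by that of a perfect Mercedes event.

*   **x-axis (Thrust):** How far is the event from a dijet?
*   **y-axis (Isotropy):** How far is the event from a uniform ring? (Low values are isotropic.)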
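To put a number on the correlation visible in the 2D histogram, one option is to report both a linear (Pearson) and a rank-based (Spearman) coefficient; the latter is insensitive to monotone, possibly nonlinear, transformations of either observable. `correlation_summary` is a hypothetical helper, not part of the original analysis.

```python
import numpy as np
from scipy.stats import rankdata

def correlation_summary(x, y):
    """Return (Pearson, Spearman) correlation coefficients of two equal-length samples."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    pearson = np.corrcoef(x, y)[0, 1]
    # Spearman = Pearson correlation of the ranks (ties receive their average rank)
    spearman = np.corrcoef(rankdata(x), rankdata(y))[0, 1]
    return float(pearson), float(spearman)
```

Once the batch cell has filled `thrusts_arr` and `isotropies_arr`, calling `correlation_summary(thrusts_arr, isotropies_arr)` returns both numbers.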

In [None]:
def load_pythia_events_robust(filepath, n_events=100):
    """
    Loads events from HDF5 file with robust error handling and debugging.
    """
    print(f"Attempting to load from: {filepath}")
    loaded_events = []
    
    try:
        with h5py.File(filepath, 'r') as f:
            print(f"Opened file: {filepath}")
            
            if 'events' not in f:
                print("Error: 'events' group not found in file.")
                print(f"Keys found: {list(f.keys())}")
                return []
                
            events_group = f['events']
            all_keys = list(events_group.keys())
            n_found = len(all_keys)
            print(f"Found {n_found} events in file.")
            
            # Sort keys to ensure deterministic loading order
            try:
                all_keys.sort(key=lambda x: int(x.split('_')[1]) if '_' in x else x)
            except (ValueError, TypeError):  # non-numeric suffixes or mixed key formats
                all_keys.sort()

            keys_to_load = all_keys[:n_events]
            
            for i, k in enumerate(keys_to_load):
                event_group = events_group[k]
                
                # Try to load particles from 'particles' or 'jets'
                data_source = None
                if 'particles' in event_group:
                    data_source = event_group['particles']
                elif 'jets' in event_group:
                    data_source = event_group['jets']
                
                if data_source is not None:
                    data = data_source[:]
                    
                    if data.shape[0] == 0:
                        continue

                    # Extract pt and phi based on shape
                    try:
                        if data.shape[1] >= 8:
                            # Standard Pythia: [px, py, pz, e, id, pt, eta, phi]
                            pts = data[:, 5]
                            phis = data[:, 7]
                        elif data.shape[1] == 4:
                            # Jets: [pt, eta, phi, m]
                            pts = data[:, 0]
                            phis = data[:, 2]
                        else:
                            if i < 5: print(f"Skipping event {k}: Unknown shape {data.shape}")
                            continue
                            
                        # Filter by pT > 0.1 GeV
                        mask = pts > 0.1
                        pts = pts[mask]
                        phis = phis[mask]
                        
                        # --- Add a recoil pseudo-particle ---
                        # Balance the net transverse momentum of the retained particles
                        if len(pts) > 0:
                            px = np.sum(pts * np.cos(phis))
                            py = np.sum(pts * np.sin(phis))
                            
                            # Recoil is negative sum
                            recoil_pt = np.sqrt(px**2 + py**2)
                            recoil_phi = np.arctan2(-py, -px)
                            
                            # Add recoil particle to arrays
                            pts = np.append(pts, recoil_pt)
                            phis = np.append(phis, recoil_phi)
                        # ---------------------------------------
                        
                        if len(pts) >= 2:
                            weights = pts / np.sum(pts)
                            loaded_events.append((phis, weights))
                        else:
                            if i < 5: print(f"Skipping event {k}: < 2 objects after cuts")
                            
                    except IndexError:
                        if i < 5: print(f"Skipping event {k}: Index error accessing columns")
                        continue
                else:
                    if i < 5: print(f"Skipping event {k}: No suitable data found")

    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()
        return []
        
    print(f"Successfully loaded {len(loaded_events)} events.")
    return loaded_events

# Resolve dataset path robustly (handles gdown nested folders)
import os
from glob import glob

def resolve_isotropy_file(name='background_events.h5'):
    candidates = [
        os.path.join('isotropy', name),
        name,
    ]
    # Also search recursively under isotropy/ in case gdown created a subfolder
    candidates.extend(sorted(glob(os.path.join('isotropy', '**', name), recursive=True)))
    for p in candidates:
        if os.path.exists(p):
            print(f"Using file: {p}")
            return p
    raise FileNotFoundError(
        f"{name} not found under 'isotropy/'. If running in Colab, run the setup cell to download."
    )

# Load real events
bg_filepath = resolve_isotropy_file('background_events.h5')
real_events = load_pythia_events_robust(bg_filepath, n_events=200000)

if len(real_events) > 0:
    print(f"Example event 0: {len(real_events[0][0])} particles")
    
    # Visualize the first real event
    phis_ev, w_ev = real_events[0]
    
    # Compute Isotropy for this event
    shift_opt = get_optimal_shift(phis_ev, w_ev, n_ref=32)
    phis_ref_iso, w_ref_iso = get_isotropic_reference(32)
    phis_ref_iso_shifted = phis_ref_iso + shift_opt
    
    plot_transport_plan(phis_ev, w_ev, phis_ref_iso_shifted, w_ref_iso, title="Isotropy Plan (Real Event -> Ring)")
    
else:
    print("No events loaded. Please check the file path and content.")

In [None]:
# Compute observables for a batch of real events
n_process = min(50000, len(real_events))
isotropies = []
thrusts = []

print(f"Processing {n_process} events...")

# Pre-compute references
phis_ref_iso, w_ref_iso = get_isotropic_reference(256)
phis_ref_dijet, w_ref_dijet = get_dijet_reference()
# The N-point ring is invariant under rotations by 2pi/N, so one period of shifts suffices
iso_period = 2*np.pi / len(phis_ref_iso)

def get_raw_isotropy(phis, weights, beta=2.0):
    """Unnormalized isotropy: EMD to the 256-point ring, minimized over one period of shifts."""
    def obj(shift):
        M = ring_distance(phis[:, None], (phis_ref_iso + shift)[None, :], beta=beta)
        return ot.emd2(weights, w_ref_iso, M)
    return minimize_scalar(obj, bounds=(0, iso_period), method='bounded').fun

def get_raw_thrust(phis, weights, beta=2.0):
    """Unnormalized thrust-like distance: align to the thrust axis, then refine the
    dijet orientation. Unlike compute_ot_thrust, the orientation is optimized, and
    bounded Brent may still stop at a local minimum for multi-modal events."""
    phis_aligned, w_aligned = align_event(phis, weights)
    def obj(shift):
        M = ring_distance(phis_aligned[:, None], (phis_ref_dijet + shift)[None, :], beta=beta)
        return ot.emd2(w_aligned, w_ref_dijet, M)
    return minimize_scalar(obj, bounds=(0, 2*np.pi), method='bounded').fun

# --- Normalization constants ---
# 1. Isotropy: raw isotropy of a perfect dijet
norm_iso = get_raw_isotropy(*get_dijet_reference())
print(f"Normalization Constant (Isotropy - Dijet): {norm_iso:.4f}")

# 2. Thrust: raw thrust-like distance of a perfect Mercedes event
norm_thrust = get_raw_thrust(*create_mercedes_event())
print(f"Normalization Constant (Thrust - Mercedes): {norm_thrust:.4f}")
# -----------------------------------------

# Same definitions for every event, so the normalization is consistent
for i in range(n_process):
    phis_ev, w_ev = real_events[i]
    isotropies.append(get_raw_isotropy(phis_ev, w_ev) / norm_iso)   # 1. Isotropy
    thrusts.append(get_raw_thrust(phis_ev, w_ev) / norm_thrust)     # 2. Thrust-like

# Plot distributions on [0, 1]; normalized values outside this range are not shown
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.hist(isotropies, bins=20, alpha=0.7, color='blue', edgecolor='black', range=(0, 1))
plt.xlabel('Isotropy (Normalized)')
plt.title('Isotropy Distribution')
plt.xlim(0, 1)

plt.subplot(1, 3, 2)
plt.hist(thrusts, bins=20, alpha=0.7, color='red', edgecolor='black', range=(0, 1))
plt.xlabel('Thrust-like (Normalized)')
plt.title('Thrust-like Distribution')
plt.xlim(0, 1)

plt.subplot(1, 3, 3)
plt.hist2d(thrusts, isotropies, bins=20, cmap='viridis', range=[[0, 1], [0, 1]])
plt.xlabel('Thrust-like (Normalized)')
plt.ylabel('Isotropy (Normalized)')
plt.title('Correlation')
plt.colorbar(label='Counts')

plt.tight_layout()
plt.show() 

isotropies_arr = np.array(isotropies)
thrusts_arr = np.array(thrusts)

In [None]:
from matplotlib import animation
from IPython.display import HTML, display

def animate_transport(phis_source, w_source, phis_target, w_target, title="Transport Animation", beta=2.0, frames=60):
    """
    Animates the optimal transport of particles from source to target on the circle.
    """
    # 1. Compute OT Plan
    M = ring_distance(phis_source[:, None], phis_target[None, :], beta=beta)
    G = ot.emd(w_source, w_target, M)
    
    # 2. Prepare paths
    # Decompose transport into sub-particles moving from i to j
    paths = [] 
    threshold = G.max() * 0.001
    for i in range(len(phis_source)):
        for j in range(len(phis_target)):
            if G[i, j] > threshold:
                start = phis_source[i]
                end = phis_target[j]
                weight = G[i, j]
                
                # Shortest angular path
                diff = end - start
                diff = (diff + np.pi) % (2*np.pi) - np.pi
                real_end = start + diff
                
                paths.append({
                    'start': start,
                    'end': real_end,
                    'weight': weight
                })
                
    # 3. Setup Plot
    fig = plt.figure(figsize=(6, 6))
    ax = plt.subplot(111, projection='polar')
    ax.set_ylim(0, 1.1)
    ax.set_yticks([])
    ax.set_title(title)
    
    # Fixed radius
    R = 1.0
    
    # Initial scatter plot (empty)
    scat = ax.scatter([], [], c='purple', alpha=0.6, zorder=3)
    
    # Background: Target distribution (faint)
    ax.scatter(phis_target, np.ones_like(phis_target)*R, s=w_target*1000, c='blue', alpha=0.1, label='Target')
    
    def init():
        scat.set_offsets(np.empty((0, 2)))
        return (scat,)

    def update(frame):
        t = frame / (frames - 1) # Progress 0 -> 1
        
        current_phis = []
        current_sizes = []
        
        for p in paths:
            # Linear interpolation of angle
            phi_t = p['start'] + t * (p['end'] - p['start'])
            current_phis.append(phi_t)
            current_sizes.append(p['weight'] * 1000) # Scale size by weight
            
        # Update scatter
        # Polar plot expects (theta, r)
        data = np.column_stack([current_phis, np.ones(len(current_phis))*R])
        scat.set_offsets(data)
        scat.set_sizes(current_sizes)
        
        # Color interpolation (Red -> Blue)
        # Red (1,0,0) to Blue (0,0,1)
        color = (1-t, 0, t)
        scat.set_color(color)
        
        return (scat,)

    anim = animation.FuncAnimation(fig, update, init_func=init, frames=frames, interval=50, blit=True)
    plt.close() # Prevent static plot from showing
    return anim

# --- Generate Animations for a Sample Event ---
if len(real_events) > 0:
    # Pick an interesting event (e.g., one with high isotropy or thrust)
    # Let's just pick the first one for now
    sample_idx = 0
    phis_ev, w_ev = real_events[sample_idx]
    
    print(f"Animating Event {sample_idx}...")

    # 1. Animate Isotropy (Event -> Ring)
    # Optimize rotation first
    def obj_iso(shift):
        phis_ref_shifted = phis_ref_iso + shift
        M = ring_distance(phis_ev[:, None], phis_ref_shifted[None, :], beta=2.0)
        return ot.emd2(w_ev, w_ref_iso, M)
    res_iso = minimize_scalar(obj_iso, bounds=(0, 2*np.pi/len(phis_ref_iso)), method='bounded')
    phis_ref_iso_opt = phis_ref_iso + res_iso.x
    
    anim_iso = animate_transport(phis_ev, w_ev, phis_ref_iso_opt, w_ref_iso, 
                                 title=f"Isotropy Transport (Event {sample_idx})")
    
    # 2. Animate Thrust (Event -> Dijet)
    # Align first
    phis_aligned, w_aligned = align_event(phis_ev, w_ev)
    # Optimize rotation
    def obj_thrust(shift):
        phis_ref_shifted = phis_ref_dijet + shift
        M = ring_distance(phis_aligned[:, None], phis_ref_shifted[None, :], beta=2.0)
        return ot.emd2(w_aligned, w_ref_dijet, M)
    res_thrust = minimize_scalar(obj_thrust, bounds=(0, 2*np.pi), method='bounded')
    phis_ref_dijet_opt = phis_ref_dijet + res_thrust.x
    
    anim_thrust = animate_transport(phis_aligned, w_aligned, phis_ref_dijet_opt, w_ref_dijet, 
                                    title=f"Thrust Transport (Event {sample_idx})")

    # Display
    print("Isotropy Animation:")
    display(HTML(anim_iso.to_jshtml()))
    print("Thrust Animation:")
    display(HTML(anim_thrust.to_jshtml()))
else:
    print("No events loaded to animate.")

In [None]:
# Most and least isotropic events

if len(isotropies_arr) > 0:
    # Find indices
    idx_most_iso = np.argmin(isotropies_arr)
    idx_least_iso = np.argmax(isotropies_arr)
    
    print(f"Most Isotropic Event: Index {idx_most_iso}, Value: {isotropies_arr[idx_most_iso]:.4f}")
    print(f"Least Isotropic Event: Index {idx_least_iso}, Value: {isotropies_arr[idx_least_iso]:.4f}")
    
    # 1. Plot Most Isotropic
    phis_ev, w_ev = real_events[idx_most_iso]
    # Re-optimize shift for visualization
    shift_opt = get_optimal_shift(phis_ev, w_ev, n_ref=256, beta=2.0)
    phis_ref_iso, w_ref_iso = get_isotropic_reference(256)
    phis_ref_iso_shifted = phis_ref_iso + shift_opt
    
    plot_transport_plan(phis_ev, w_ev, phis_ref_iso_shifted, w_ref_iso, 
                        title=f"Most Isotropic Event (Idx {idx_most_iso})", beta=2.0)
                        
    # 2. Plot Least Isotropic
    phis_ev, w_ev = real_events[idx_least_iso]
    # Re-optimize shift for visualization
    shift_opt = get_optimal_shift(phis_ev, w_ev, n_ref=256, beta=2.0)
    phis_ref_iso, w_ref_iso = get_isotropic_reference(256)
    phis_ref_iso_shifted = phis_ref_iso + shift_opt
    
    plot_transport_plan(phis_ev, w_ev, phis_ref_iso_shifted, w_ref_iso, 
                        title=f"Least Isotropic Event (Idx {idx_least_iso})", beta=2.0)