In [None]:
# Many studies in this notebook are based on
# https://arxiv.org/abs/1902.02346
# The Metric Space of Collider Events
# Patrick T. Komiske, Eric M. Metodiev, Jesse Thaler

# ---------------------------------------------------------
# GOOGLE COLAB SETUP
# ---------------------------------------------------------
import sys
import os
import subprocess
import re

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Detected Google Colab environment.")
    
    # ---------------------------------------------------------
    # 1. Smart Installation Check
    # ---------------------------------------------------------
    # We check the installed NumPy version via pip *before* deciding to install.
    # This prevents the cell from hanging on re-runs after a restart.
    
    current_numpy_version = "0.0.0"
    try:
        # Run 'pip show numpy' to get the version on disk
        res = subprocess.run([sys.executable, "-m", "pip", "show", "numpy"], capture_output=True, text=True)
        m = re.search(r"Version:\s*(\d+\.\d+\.\d+)", res.stdout)
        if m:
            current_numpy_version = m.group(1)
    except Exception:
        pass

    print(f"Current NumPy version on disk: {current_numpy_version}")
    
    # Determine if we need to install/downgrade
    needs_install = False
    
    # Condition 1: NumPy must be 1.x
    if not current_numpy_version.startswith("1."):
        print("NumPy 2.x (or unknown) detected. Downgrade required.")
        needs_install = True
    else:
        # Condition 2: Check if energyflow/geomloss are installed
        res_pkg = subprocess.run([sys.executable, "-m", "pip", "show", "energyflow"], capture_output=True, text=True)
        if "Name: energyflow" not in res_pkg.stdout:
            print("energyflow not found. Installation required.")
            needs_install = True

    if needs_install:
        print("Installing required packages (this may take a minute)...")
        # AGGRESSIVE COMPATIBILITY FIX:
        # Force reinstall to ensure we get numpy<2.0.0 and compatible scipy/sklearn
        # Added geomloss for GPU OT
        !pip install -q -U --force-reinstall "numpy<2.0.0" "scipy<1.13.0" "scikit-learn<1.5.0" energyflow pot uproot awkward vector geomloss
        
        print("\n" + "="*80)
        print("INSTALLATION COMPLETE.")
        print("CRITICAL: You MUST restart the Colab runtime now.")
        print("1. Go to menu: Runtime > Restart session")
        print("2. Run this cell again (it will skip installation next time)")
        print("="*80 + "\n")
    else:
        print("Environment appears correct (NumPy 1.x installed). Skipping installation.")

    # ---------------------------------------------------------
    # 2. Download Dataset
    # ---------------------------------------------------------
    folder_id = '1CJe9xkIk1QmTXJ8g__zagAvn3uoxue5a'
    output_folder = 'efjets'
    
    if not os.path.exists(output_folder):
        print(f"Downloading dataset to {output_folder}/...")
        !pip install -q -U --no-cache-dir gdown --pre
        import gdown
        gdown.download_folder(id=folder_id, output=output_folder, quiet=False)
    else:
        print(f"Folder '{output_folder}' already exists. Skipping download.")
        
else:
    print("Not running in Google Colab. Assuming local environment.")

In [None]:
# Setup and Imports
import numpy as np
import matplotlib.pyplot as plt
import glob
import os

# CRITICAL: Check NumPy version before importing other libraries.
# Compare the integer major version: comparing version *strings* is lexicographic
# (e.g. '10.0.0' < '2.0.0'), which is not a valid version ordering.
print(f"NumPy version: {np.__version__}")
NUMPY_MAJOR = int(np.__version__.split('.')[0])
if NUMPY_MAJOR >= 2:
    msg = (
        f"Detected NumPy {np.__version__}, but this notebook requires NumPy < 2.0.0.\n"
        "Please run the 'GOOGLE COLAB SETUP' cell above to install the correct versions,\n"
        "then RESTART the runtime (Runtime > Restart session) and run this cell again."
    )
    raise RuntimeError(msg)

import ot
import energyflow as ef
from sklearn.manifold import TSNE

print("Libraries imported successfully.")

## 1. Data Loading: EnergyFlow Datasets

We will load jets from the `efjets/` directory.
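*   **Quark Jets:** From `QG_jets` files (file label 1 → class id 1).
*   **Gluon Jets:** From `QG_jets` files (file label 0 → class id 0).
*   **Top Jets:** From `top_qcd` files (file label 1 → class id 2; the QCD jets in these files, file label 0, are skipped).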

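Each file stores zero-padded jets as an array of shape `(N, M, 4)` with per-particle features `(pt, y, phi, pid)`. Only `(pt, y, phi)` are used, and zero-pT padding entries are dropped.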

In [None]:
# Data Loading Logic

# Class ids used throughout the analysis. Original file labels are remapped onto
# these ids when the jets are preprocessed:
#   0 -> Gluon, 1 -> Quark  (QG_jets files)
#   2 -> Top                (top_qcd files)
class_names = dict(enumerate(['Gluon', 'Quark', 'Top']))

def preprocess_jets(X, y, source_type, max_jets=1000, max_particles=128):
    """
    Convert zero-padded jets into centred point clouds with normalised pT weights.

    Parameters
    ----------
    X : array, shape (N, M, >=3)
        Zero-padded jets. Columns 0, 1, 2 are (pt, y, phi); further columns (e.g. pid)
        are ignored. Particles with pt <= 0 are treated as padding.
    y : array-like, shape (N,)
        Original per-jet labels from the source file.
    source_type : {'QG', 'Top'}
        'QG'  : keep every jet, labels unchanged (0=Gluon, 1=Quark).
        'Top' : keep only jets with original label 1 and relabel them 2 (Top);
                QCD jets (original label 0) are skipped.
    max_jets : int
        Stop once this many jets have been *kept* (skipped jets do not count).
    max_particles : int
        Keep at most this many highest-pT constituents per jet.

    Returns
    -------
    jets_X : list of (K, 2) arrays
        Constituent (y, phi) relative to the pT-weighted jet axis, phi wrapped to [-pi, pi).
    jets_w : list of (K,) arrays
        Constituent pT fractions, summing to 1 over the kept constituents.
    labels : list of int
        Class ids (0=Gluon, 1=Quark, 2=Top).
    jets_pt : list of float
        Scalar pT sum over all non-padding constituents, computed *before* the
        max_particles truncation.

    Raises
    ------
    ValueError
        If `source_type` is not 'QG' or 'Top'.
    """
    if source_type not in ('QG', 'Top'):
        raise ValueError(f"Unknown source_type {source_type!r}; expected 'QG' or 'Top'.")

    jets_X = []
    jets_w = []
    jets_pt = []
    labels = []

    for i in range(len(X)):
        if len(jets_X) >= max_jets:
            break

        # Map the label first so rejected jets (QCD in Top files) are skipped before
        # any per-particle work is done.
        original_label = int(y[i])
        if source_type == 'QG':
            new_label = original_label  # 0=Gluon, 1=Quark already match our class ids
        elif original_label == 1:
            new_label = 2  # Top
        else:
            continue  # QCD jet in a top_qcd file: not used

        jet_data = X[i]

        # Filter zero-padded particles (pt=0); skip empty/single-particle jets
        mask = jet_data[:, 0] > 0
        if np.count_nonzero(mask) < 2:
            continue

        pts = jet_data[mask, 0]
        ys = jet_data[mask, 1]
        phis = jet_data[mask, 2]

        # Total pT of the jet, taken before truncation and normalisation
        total_pt = pts.sum()

        # Keep only the hardest max_particles constituents
        if len(pts) > max_particles:
            idx = np.argsort(pts)[::-1][:max_particles]
            pts, ys, phis = pts[idx], ys[idx], phis[idx]

        # Centre on the pT-weighted axis: linear mean in y, circular mean in phi so jets
        # straddling the +-pi seam are handled correctly.
        y_avg = np.average(ys, weights=pts)
        phi_avg = np.arctan2(np.average(np.sin(phis), weights=pts),
                             np.average(np.cos(phis), weights=pts))

        ys_centered = ys - y_avg
        # Wrap the centred phi back into [-pi, pi)
        phis_centered = (phis - phi_avg + np.pi) % (2 * np.pi) - np.pi

        jets_X.append(np.stack([ys_centered, phis_centered], axis=1))
        jets_w.append(pts / pts.sum())
        jets_pt.append(total_pt)
        labels.append(new_label)

    return jets_X, jets_w, labels, jets_pt

def load_ef_data(max_jets_per_class=250, seed=42):
    """
    Load a class-balanced sample of Gluon (0), Quark (1) and Top (2) jets from `efjets/`.

    Parameters
    ----------
    max_jets_per_class : int
        Upper bound on the number of jets kept per class after balancing.
    seed : int or None
        Seed for the random subsampling used to balance classes, so the sample is
        reproducible. Pass None for fresh, non-reproducible sampling.

    Returns
    -------
    final_X : list of (K, 2) arrays of centred (y, phi) constituent coordinates
    final_w : list of (K,) arrays of constituent pT fractions
    labels  : int ndarray of class ids
    pts     : ndarray of total jet pT (see preprocess_jets)
    """
    rng = np.random.default_rng(seed)

    all_X = []
    all_w = []
    all_labels = []
    all_pt = []

    def _collect(X_proc, w_proc, l_proc, pt_proc):
        # Append one file's processed jets to the running pools.
        all_X.extend(X_proc)
        all_w.extend(w_proc)
        all_labels.extend(l_proc)
        all_pt.extend(pt_proc)

    # 1. Load QG Jets (files sorted so the selection does not depend on directory order)
    qg_files = sorted(glob.glob('efjets/QG_jets_*.npz'))
    print(f"Found {len(qg_files)} QG files.")

    jets_needed = max_jets_per_class * 2  # Gluon + Quark
    jets_loaded = 0

    for f in qg_files:
        if jets_loaded >= jets_needed:
            break
        print(f"Loading {f}...")
        # `with` closes the zip archive that np.load opens for .npz files.
        with np.load(f) as data:
            X_raw = data['X']
            y_raw = data['y']  # 0=Gluon, 1=Quark

        # Only request the jets still missing, so later files are not over-read.
        X_proc, w_proc, l_proc, pt_proc = preprocess_jets(
            X_raw, y_raw, 'QG', max_jets=jets_needed - jets_loaded)
        _collect(X_proc, w_proc, l_proc, pt_proc)
        jets_loaded += len(X_proc)

    # 2. Load Top Jets
    top_files = sorted(glob.glob('efjets/top_qcd_*.npz'))
    print(f"Found {len(top_files)} Top/QCD files.")

    jets_needed_top = max_jets_per_class
    jets_loaded_top = 0

    for f in top_files:
        if jets_loaded_top >= jets_needed_top:
            break
        print(f"Loading {f}...")
        with np.load(f) as data:
            # top_qcd files may store their arrays under 'data'/'labels' instead of 'X'/'y'
            if 'data' in data:
                X_raw = data['data']
                y_raw = data['labels']
            else:
                X_raw = data['X']
                y_raw = data['y']

        # QCD jets are skipped inside preprocess_jets and not counted; keep up to 4x the
        # needed Top jets so the random subsample below draws from a larger pool.
        X_proc, w_proc, l_proc, pt_proc = preprocess_jets(
            X_raw, y_raw, 'Top', max_jets=jets_needed_top * 4)
        _collect(X_proc, w_proc, l_proc, pt_proc)
        jets_loaded_top += len(l_proc)

    # Subsample to balance classes exactly
    final_X = []
    final_w = []
    final_labels = []
    final_pt = []

    arr_labels = np.array(all_labels)

    for cls_id in [0, 1, 2]:
        indices = np.where(arr_labels == cls_id)[0]
        if len(indices) > max_jets_per_class:
            indices = rng.choice(indices, max_jets_per_class, replace=False)

        print(f"Class {class_names[cls_id]}: {len(indices)} jets")

        for idx in indices:
            final_X.append(all_X[idx])
            final_w.append(all_w[idx])
            final_labels.append(all_labels[idx])
            final_pt.append(all_pt[idx])

    return final_X, final_w, np.array(final_labels, dtype=int), np.array(final_pt)

# Load
jets_X, jets_w, labels, jets_pt = load_ef_data(max_jets_per_class=500)
N = len(jets_X)
print(f"Total jets loaded: {N}")

# Fail fast: every later cell indexes into these arrays, and an empty dataset would
# otherwise surface as an obscure IndexError far downstream.
if N == 0:
    raise RuntimeError(
        "No jets were loaded. Check that the 'efjets/' folder exists and contains "
        "QG_jets_*.npz and top_qcd_*.npz files (run the setup cell above)."
    )

In [None]:
# Visualization & Utility Functions

def plot_jet_scatter(ax, x0, y0, X, w, scale=0.25, max_markersize=60, cmap='viridis', alpha=0.8):
    """
    Draw one jet as a scatter of its constituents on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x0, y0 : float
        Position of the jet centre in `ax` data coordinates.
    X : (M, 2) array
        Constituent coordinates (y, phi), multiplied by `scale` before offsetting.
    w : (M,) array
        Constituent weights (pT fractions); used for marker colour and size.
    scale : float
        Spatial magnification applied to X.
    max_markersize : float
        Marker area given to the hardest constituent; others scale linearly with w.
    cmap, alpha : passed through to `ax.scatter`.

    Returns
    -------
    The artist returned by `ax.scatter`, or None if the jet has no constituents
    (nothing is drawn in that case).
    """
    X = np.asarray(X, dtype=float)
    w = np.asarray(w, dtype=float)
    if w.size == 0:
        return None
    xs = x0 + X[:, 0] * scale
    ys = y0 + X[:, 1] * scale
    # The 1e-12 guards against division by zero for an all-zero weight vector.
    sizes = (w / (w.max() + 1e-12)) * max_markersize
    return ax.scatter(xs, ys, s=sizes, c=w, cmap=cmap, alpha=alpha, edgecolors='none')

print("Visualization utilities defined.")

In [None]:
# ---------------------------------------------------------
# VISUALIZATION: OT Plan & Animation
# ---------------------------------------------------------
# We select two random jets and visualize the transport plan.
# Uses EnergyFlow (ef.emd.emd) for IRC-safe EMD calculation.

import matplotlib.animation as animation
from IPython.display import HTML

# CONFIGURATION: Select classes to compare
# 0: Gluon, 1: Quark, 2: Top
source_class_id = 1 # Quark
target_class_id = 2 # Top

source_name = class_names[source_class_id]
target_name = class_names[target_class_id]

print(f"Visualizing transport from {source_name} to {target_name}...")

# Select first available jet of each class
# NOTE(review): np.where(...)[0][0] raises IndexError if a class has no jets in `labels`.
idx_source = np.where(labels == source_class_id)[0][0]
idx_target = np.where(labels == target_class_id)[0][0]

X_source, w_source = jets_X[idx_source], jets_w[idx_source]
X_target, w_target = jets_X[idx_target], jets_w[idx_target]

# Prepare events for EnergyFlow: (pt, y, phi)
# jets_X is (y, phi), jets_w is pt (normalized)
# Each event is an (M, 3) array: column 0 = pT fraction, columns 1-2 = centred (y, phi).
ev_source = np.column_stack((w_source, X_source))
ev_target = np.column_stack((w_target, X_target))

print("Computing OT plan using EnergyFlow (ef.emd.emd)...")
# ef.emd is a module, the function is ef.emd.emd
# It returns (dist, flow) when return_flow=True
# The flow is a dense matrix with rows indexing source particles and columns target particles.
# NOTE(review): per the EnergyFlow docs, if the two total pTs differ (even by round-off)
# a balancing particle may be appended, adding one row or column; the loops below only
# index the first len(X_source) x len(X_target) entries, so they ignore it.
dist_val, gamma = ef.emd.emd(ev_source, ev_target, return_flow=True)

print(f"EMD Distance: {dist_val}")
print(f"Flow matrix shape: {gamma.shape}")
print(f"Max flow weight in gamma: {gamma.max()}")

# Plot Static Transport Plan
plt.figure(figsize=(12, 5))

# Plot 1: Source (marker area proportional to pT fraction)
plt.subplot(1, 3, 1)
plt.scatter(X_source[:, 0], X_source[:, 1], s=w_source*1000, c='r', alpha=0.6, label=source_name)
plt.title(f"Source ({source_name})")
plt.xlim(-0.8, 0.8); plt.ylim(-0.8, 0.8)
plt.grid(True, alpha=0.3)

# Plot 2: Target
plt.subplot(1, 3, 2)
plt.scatter(X_target[:, 0], X_target[:, 1], s=w_target*1000, c='b', alpha=0.6, label=target_name)
plt.title(f"Target ({target_name})")
plt.xlim(-0.8, 0.8); plt.ylim(-0.8, 0.8)
plt.grid(True, alpha=0.3)

# Plot 3: Transport Lines
plt.subplot(1, 3, 3)
plt.title(f"Transport Plan (EMD={dist_val:.4f})")
# Draw one line per (source, target) pair carrying more than 1% of the largest flow;
# line width is proportional to the transported pT fraction.
threshold = gamma.max() * 0.01 
count_lines = 0
for i in range(len(X_source)):
    for j in range(len(X_target)):
        if gamma[i, j] > threshold:
            plt.plot([X_source[i, 0], X_target[j, 0]], [X_source[i, 1], X_target[j, 1]], 
                     'k-', alpha=0.2, lw=gamma[i, j]*200)
            count_lines += 1
            
print(f"Drawn {count_lines} transport lines (threshold={threshold:.6f}).")

# Overlay both jets (smaller markers) on top of the transport lines
plt.scatter(X_source[:, 0], X_source[:, 1], s=w_source*200, c='r', alpha=0.4)
plt.scatter(X_target[:, 0], X_target[:, 1], s=w_target*200, c='b', alpha=0.4)
plt.xlim(-0.8, 0.8); plt.ylim(-0.8, 0.8)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# ---------------------------------------------------------
# ANIMATION
# ---------------------------------------------------------
# Displacement interpolation along the transport plan: each packet gamma[i, j] moves
# linearly from X_source[i] to X_target[j] as t goes 0 -> 1.
# NOTE: the textbook McCann formula argmin_mu (1-t)W_2^2(mu0, mu) + t W_2^2(mu1, mu) is for
# the squared W_2 metric; for the W_1-type EMD computed here, linear motion along the optimal
# plan is a valid (but not unique) geodesic between the two jets.

fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(-0.8, 0.8)
ax.set_ylim(-0.8, 0.8)
ax.grid(True, alpha=0.3)
ax.set_title(f"Optimal Transport Interpolation ({source_name} -> {target_name})")

# Prepare particle paths
# Each entry in gamma[i, j] represents a "packet" of energy moving from X_source[i] to X_target[j]
paths = []
weights = []
colors = []  # NOTE(review): not used by the animation code below

# Use a very low threshold for animation to include most particles
anim_threshold = 1e-8

for i in range(len(X_source)):
    for j in range(len(X_target)):
        if gamma[i, j] > anim_threshold:
            start = X_source[i]
            end = X_target[j]
            paths.append((start, end))
            weights.append(gamma[i, j])

weights = np.array(weights)
# Marker area per packet = flow weight * s_scale (a fixed display scale, not a normalisation)
s_scale = 2000

# Empty scatter artist; `update` fills in positions, sizes and colours each frame.
scatter = ax.scatter([], [], s=[], c=[], alpha=0.6)

def update(frame):
    """Animation callback: place every transport packet at fraction t along its path.

    t = frame / 20, so frames 0..20 move packets from source to target and later frames
    play the motion back (ping-pong). Colour fades from red (source) to blue (target).
    Returns a 1-tuple with the updated artist, as FuncAnimation(blit=True) expects.
    """
    t = frame / 20.0
    if t > 1.0:
        t = 2.0 - t  # reverse direction after reaching the target

    n_packets = len(paths)
    if n_packets:
        starts = np.array([p[0] for p in paths])
        ends = np.array([p[1] for p in paths])
        positions = (1 - t) * starts + t * ends
    else:
        # No packets: set_offsets still needs a (0, 2) array
        positions = np.zeros((0, 2))

    scatter.set_offsets(positions)
    scatter.set_sizes(weights * s_scale)
    scatter.set_color([(1 - t, 0, t)] * n_packets)
    return scatter,

# 40 frames at 100 ms each; with t = frame / 20 in `update`, frames 0-20 move packets to
# the target and frames 21-39 play them back. blit=True relies on `update` returning the
# artists it modified.
anim = animation.FuncAnimation(fig, update, frames=40, interval=100, blit=True)
# Close the figure so only the JS animation below is rendered, not an extra static frame.
plt.close()
HTML(anim.to_jshtml())

In [None]:
# OT COMPUTATION (EnergyFlow / CPU)
# ---------------------------------------------------------
# Uses EnergyFlow (ef.emd.emd) for all calculations to ensure IRC safety.
# This runs on CPU using the FastEMD C++ backend provided by EnergyFlow.

import multiprocessing
from joblib import Parallel, delayed
import time

def compute_ot_pair_ef(i, j, Xi, wi, Xj, wj, R=1.0):
    """
    Compute the EMD between jet i and jet j using EnergyFlow.

    Parameters
    ----------
    i, j : int
        Jet indices, returned unchanged so results from parallel workers can be
        written back into the distance matrix.
    Xi, Xj : (M, 2) arrays
        Constituent (y, phi) coordinates.
    wi, wj : (M,) arrays
        Constituent pT weights.
    R : float
        Jet-radius parameter forwarded to ``ef.emd.emd``. The default 1.0 matches the
        earlier call, which relied on ef.emd.emd's own default. For jets with equal total
        weight (as with normalised pT fractions) the EMD scales as 1/R (Eq. 1 of
        arXiv:1902.02346), so R only rescales all distances uniformly.

    Returns
    -------
    (i, j, emd) with emd converted to a Python float.
    """
    # EnergyFlow expects each event as rows of (pt, y, phi).
    ev_i = np.column_stack((wi, Xi))
    ev_j = np.column_stack((wj, Xj))

    # Note: ef.emd is a module, the function is ef.emd.emd
    val = ef.emd.emd(ev_i, ev_j, R=R)
    return i, j, float(val)

print("Computing pairwise EMD using EnergyFlow (CPU parallel)...")
N = len(jets_X)
n_cores = min(8, multiprocessing.cpu_count())

# Generate pairs (upper triangle only: the EMD is symmetric and zero on the diagonal)
pairs = [(i, j) for i in range(N) for j in range(i+1, N)]
print(f"Total jets: {N}")
print(f"Total pairs to compute: {len(pairs)}")
print(f"Using {n_cores} cores.")

start_time = time.time()

# Run parallel computation. Each task is given only the two jets it needs rather than
# the full jets_X / jets_w lists.
results = Parallel(n_jobs=n_cores, verbose=5)(
    delayed(compute_ot_pair_ef)(i, j, jets_X[i], jets_w[i], jets_X[j], jets_w[j]) for i, j in pairs
)

# Fill the symmetric matrix; the diagonal stays 0 (distance of a jet to itself).
D = np.zeros((N, N))
for i, j, val in results:
    D[i, j] = val
    D[j, i] = val

end_time = time.time()
elapsed = end_time - start_time
print(f"Computation complete in {elapsed:.2f} seconds.")
# With fewer than two jets there are no pairs, so a per-pair average is undefined.
if pairs:
    print(f"Average time per pair: {elapsed/len(pairs)*1000:.2f} ms")
else:
    print("No jet pairs to compute (need at least 2 jets).")

In [None]:
# ---------------------------------------------------------
# ANALYSIS & PLOTS
# ---------------------------------------------------------

# ---------------------------------------------------------
# 1. Q/G Only t-SNE
# ---------------------------------------------------------
# Filter for Quark (1) and Gluon (0) only
# Fixed per-class colours (tab10 entry i for class id i) so Gluon/Quark/Top keep the
# same colour in every t-SNE plot below.
colors = plt.cm.tab10(np.linspace(0, 1, 10))

# t-SNE requires perplexity < n_samples; the usual value 30 is capped for small samples.
TSNE_PERPLEXITY = 30

# Filter for Quark (1) and Gluon (0) only
qg_indices = np.where((labels == 0) | (labels == 1))[0]

# t-SNE needs at least two samples.
if len(qg_indices) > 1:
    print(f"Running t-SNE for Q/G jets only ({len(qg_indices)} jets)...")

    # Extract the Q/G-only block of the EMD matrix
    D_qg = D[np.ix_(qg_indices, qg_indices)]
    labels_qg = labels[qg_indices]

    # Run t-SNE on the precomputed EMD distances ('random' init is required with 'precomputed')
    tsne_qg = TSNE(n_components=2, metric='precomputed', init='random', random_state=42,
                   perplexity=min(TSNE_PERPLEXITY, len(qg_indices) - 1))
    emb_qg = tsne_qg.fit_transform(D_qg)

    # Plot
    plt.figure(figsize=(10, 8))
    for i in [0, 1]: # Gluon, Quark
        if i in labels_qg:
            mask = labels_qg == i
            plt.scatter(emb_qg[mask, 0], emb_qg[mask, 1], color=colors[i],
                        label=class_names[i], s=40, alpha=0.8)

    plt.title('t-SNE embedding: Quark vs Gluon Jets')
    plt.legend()
    plt.grid(True, alpha=0.2)
    plt.show()
else:
    print("Not enough Q/G jets for a separate t-SNE plot.")

# ---------------------------------------------------------
# 2. Full t-SNE (All Classes)
# ---------------------------------------------------------
print("Running t-SNE for all classes...")
tsne = TSNE(n_components=2, metric='precomputed', init='random', random_state=42,
            perplexity=min(TSNE_PERPLEXITY, len(D) - 1))
emb = tsne.fit_transform(D)

# Plot 1: Standard t-SNE, coloured by class id so colours match the Q/G plot
plt.figure(figsize=(10, 8))
present_classes = np.unique(labels)

for i in present_classes:
    mask = labels == i
    plt.scatter(emb[mask, 0], emb[mask, 1], color=colors[i], label=class_names[i], s=40, alpha=0.8)

plt.title('t-SNE embedding of All Jets (Quark, Gluon, Top)')
plt.legend()
plt.grid(True, alpha=0.2)
plt.show()


In [None]:
# Plot 3: Morphological Confusion Matrix
# For each pair of classes (i, j), show the class-i jet with the smallest mean EMD to
# the class-j jets. Rows are the "From" class i, columns the "Closest to" class j.
n_classes = len(present_classes)
out_dir = 'figures/confusion_matrix'
os.makedirs(out_dir, exist_ok=True)


def _apply_pt_cut(inds, thresh):
    """Keep the jets in `inds` with total pT >= `thresh`; fall back to all of `inds` if none pass."""
    kept = inds[jets_pt[inds] >= thresh]
    return kept if len(kept) > 0 else inds


if n_classes > 1:
    fig = plt.figure(figsize=(9, 9))

    # Modest cut at the 25th percentile of total jet pT to accentuate morphological
    # differences. jets_pt is the scalar pT sum from preprocess_jets, so this is a pT
    # cut, not a mass cut.
    pt_thresh = np.percentile(jets_pt, 25)
    mass_thresh = pt_thresh  # legacy name, kept for any later cell that reads it
    print(f"Using jet pT threshold (25th percentile): {pt_thresh:.3f}")

    saved_files = []

    for i_idx, i in enumerate(present_classes):
        inds_i = _apply_pt_cut(np.where(labels == i)[0], pt_thresh)
        safe_i = str(class_names[i]).replace(' ', '_')

        for j_idx, j in enumerate(present_classes):
            inds_j = _apply_pt_cut(np.where(labels == j)[0], pt_thresh)
            safe_j = str(class_names[j]).replace(' ', '_')

            # Mean EMD from each class-i jet to all class-j jets
            # (on the diagonal the zero self-distance is included in the mean).
            avg_dist = D[np.ix_(inds_i, inds_j)].mean(axis=1)
            # Global indices of the (up to) 3 class-i jets closest to class j on average
            proto_globals = inds_i[np.argsort(avg_dist)[:3]]

            # Grid cell shows the rank-1 prototype. Draw on `fig` explicitly instead of
            # relying on pyplot's "current figure", which the per-prototype figures change.
            best_global_idx = proto_globals[0]
            ax = fig.add_subplot(n_classes, n_classes, i_idx * n_classes + j_idx + 1)
            # Increase scale so jets fill the subplot and use larger markers for visibility
            plot_jet_scatter(ax, 0, 0, jets_X[best_global_idx], jets_w[best_global_idx], scale=1.2, max_markersize=150)
            ax.set_xlim(-0.9, 0.9); ax.set_ylim(-0.9, 0.9)
            ax.set_xticks([]); ax.set_yticks([])
            ax.set_aspect('equal')

            if i_idx == 0:
                ax.set_title(f"Closest to\n{class_names[j]}", fontsize=12)
            if j_idx == 0:
                ax.set_ylabel(f"From\n{class_names[i]}", fontsize=12, rotation=0, labelpad=40)

            # Save the top-3 prototypes for this cell as individual images
            for rank, gidx in enumerate(proto_globals):
                fig_single = plt.figure(figsize=(3, 3))
                ax_s = fig_single.add_subplot(1, 1, 1)
                plot_jet_scatter(ax_s, 0, 0, jets_X[gidx], jets_w[gidx], scale=1.2, max_markersize=150)
                ax_s.set_xlim(-0.9, 0.9); ax_s.set_ylim(-0.9, 0.9)
                ax_s.axis('off')
                fname = os.path.join(out_dir, f"{safe_i}_to_{safe_j}_proto_rank{rank+1}.png")
                fig_single.savefig(fname, dpi=150, bbox_inches='tight', pad_inches=0.08)
                plt.close(fig_single)
                saved_files.append(fname)

    # Finalise title and layout BEFORE saving, so the saved grid matches the displayed one.
    fig.suptitle('Morphological Confusion Matrix\n(Representative jets)', y=1.02)
    fig.tight_layout()
    grid_fname = os.path.join(out_dir, 'morph_confusion_matrix_grid.png')
    fig.savefig(grid_fname, dpi=200, bbox_inches='tight')
    saved_files.append(grid_fname)
    plt.show()

    print(f"Saved {len(saved_files)} files to {out_dir}:\n - " + "\n - ".join(saved_files))


In [None]:
# Absolute Jet Mass Distribution with Medoids

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Calculate absolute jet masses from constituents.
# The invariant mass is unchanged by the (y, phi) centring done in preprocess_jets
# (a longitudinal boost plus an azimuthal rotation), so centred coordinates can be used.
# NOTE: jets_pt is the pT sum before the max_particles truncation, so for truncated jets
# the kept constituents are rescaled to the full jet pT.
print("Calculating jet masses from constituents...")


def _jet_mass(X, w, total_pt):
    """Invariant mass of a jet of massless constituents.

    X is (M, 2) of (y, phi), w the (M,) pT fractions, total_pt the jet pT. Uses
    E = pt cosh(y), px = pt cos(phi), py = pt sin(phi), pz = pt sinh(y), m^2 = E^2 - |p|^2.
    """
    pt = w * total_pt
    y, phi = X[:, 0], X[:, 1]
    E = np.sum(pt * np.cosh(y))
    px = np.sum(pt * np.cos(phi))
    py = np.sum(pt * np.sin(phi))
    pz = np.sum(pt * np.sinh(y))
    m2 = E**2 - px**2 - py**2 - pz**2
    return np.sqrt(max(0, m2))  # clip small negative m^2 caused by round-off


m_list = [_jet_mass(X_j, w_j, ptot_j) for X_j, w_j, ptot_j in zip(jets_X, jets_w, jets_pt)]
abs_masses = np.array(m_list)

# Histogram binning. The medoid insets are placed from the same constants, so change
# them here rather than in the plotting calls.
MASS_NBINS = 20
MASS_RANGE = (0, 250)  # GeV (typical for these datasets)
mass_lo, mass_hi = MASS_RANGE
mass_bin_frac = 1.0 / MASS_NBINS  # one bin's width as a fraction of the axes width

fig = plt.figure(figsize=(14, 12))
plot_classes = [c for c in [2, 1, 0] if c in present_classes]
n_rows = len(plot_classes) + 1
height_ratios = [1]*len(plot_classes) + [3]

gs = fig.add_gridspec(n_rows, 1, height_ratios=height_ratios, hspace=0.1)

# Histogram (Bottom)
ax_hist = fig.add_subplot(gs[-1])
counts, bins, _ = ax_hist.hist([abs_masses[labels==i] for i in present_classes],
                               bins=MASS_NBINS, range=MASS_RANGE,
                               stacked=True, density=True, alpha=0.5,
                               label=[class_names[i] for i in present_classes])
ax_hist.legend(fontsize='x-large')
ax_hist.set_xlabel("Jet Mass [GeV]", fontsize='x-large')
ax_hist.set_ylabel("Density", fontsize='x-large')
ax_hist.set_xlim(mass_lo, mass_hi)

# Medoid Rows: one row per class. In each mass bin holding >= 3 jets of that class,
# draw the medoid (the jet minimising the summed EMD to the other jets in that bin).
for row_idx, cls_id in enumerate(plot_classes):
    ax_row = fig.add_subplot(gs[row_idx], sharex=ax_hist)
    ax_row.axis('off')
    ax_row.set_xlim(mass_lo, mass_hi)
    ax_row.set_ylim(-1, 1)

    ax_row.text(-0.02, 0.5, f"{class_names[cls_id]}\nMedoids",
                transform=ax_row.transAxes, ha='right', va='center', fontsize=12, fontweight='bold')

    for k in range(len(bins)-1):
        mask = (abs_masses >= bins[k]) & (abs_masses < bins[k+1]) & (labels == cls_id)
        idxs = np.where(mask)[0]

        if len(idxs) > 2:
            d_sub = D[np.ix_(idxs, idxs)]
            medoid_local = np.argmin(d_sub.sum(axis=1))
            medoid_idx = idxs[medoid_local]

            # Inset centred on the bin and one bin wide, in axes-fraction coordinates
            x_c = (bins[k] + bins[k+1]) / 2
            x_norm = (x_c - mass_lo) / (mass_hi - mass_lo)

            ax_ins = inset_axes(ax_row, width="80%", height="90%",
                                bbox_to_anchor=(x_norm - 0.5 * mass_bin_frac, 0, mass_bin_frac, 1),
                                bbox_transform=ax_row.transAxes, loc='center')

            plot_jet_scatter(ax_ins, 0, 0, jets_X[medoid_idx], jets_w[medoid_idx], scale=0.8, max_markersize=40)
            ax_ins.set_xticks([])
            ax_ins.set_yticks([])
            ax_ins.axis('off')

plt.show()

# ---------------------------------------------------------
# Jet Multiplicity Distribution with Medoids
# ---------------------------------------------------------
print("Generating jet multiplicity distribution + medoids...")

# Rebinning factor (group integer multiplicities into bins of width mult_rebin)
# Set to 1 for no rebinning, 2 for pairs (0-1,2-3,...), etc.
mult_rebin = 5  # <-- change this value to rebin multiplicities

# Multiplicity = number of stored constituents per jet. preprocess_jets keeps at most
# max_particles of them (128 by default), so multiplicities are capped there.
multiplicities = np.array([len(w) for w in jets_w])
# Grouped multiplicities
grouped = multiplicities // mult_rebin
group_max = int(grouped.max()) if grouped.size > 0 else 0
# The x-axis spans [-0.5, group_max + 0.5]: group_max + 1 equal slots, one per bar.
slot_frac = 1.0 / (group_max + 1)

fig_mult = plt.figure(figsize=(14, 12))
plot_classes = [c for c in [2, 1, 0] if c in present_classes]
n_rows = len(plot_classes) + 1
height_ratios = [1]*len(plot_classes) + [3]

gs2 = fig_mult.add_gridspec(n_rows, 1, height_ratios=height_ratios, hspace=0.1)

# Histogram (Bottom) for grouped multiplicity
ax_hist_m = fig_mult.add_subplot(gs2[-1])
# Build grouped arrays per class
grouped_per_class = [grouped[labels==i] for i in present_classes]
# Unit-width bins centred on each integer group value
bins_mult = np.arange(0, group_max + 2) - 0.5
counts_m, _, _ = ax_hist_m.hist(grouped_per_class,
                                 bins=bins_mult, stacked=True, density=True, alpha=0.6,
                                 label=[class_names[i] for i in present_classes])
# xticks: label ranges
xticks = np.arange(0, group_max+1)
xtick_labels = [f"{g*mult_rebin}-{g*mult_rebin + mult_rebin - 1}" if mult_rebin>1 else f"{g*mult_rebin}" for g in xticks]
ax_hist_m.set_xticks(xticks)
ax_hist_m.set_xticklabels(xtick_labels, rotation=45)
ax_hist_m.set_xlabel("Jet constituent multiplicity", fontsize='x-large')
ax_hist_m.set_ylabel("Density", fontsize='x-large')
ax_hist_m.set_xlim(-0.5, group_max + 0.5)
ax_hist_m.legend(fontsize='x-large')

# Medoid rows for multiplicity (same layout as mass medoids)
for row_idx, cls_id in enumerate(plot_classes):
    ax_row = fig_mult.add_subplot(gs2[row_idx], sharex=ax_hist_m)
    ax_row.axis('off')
    ax_row.set_xlim(-0.5, group_max + 0.5)
    ax_row.set_ylim(-1, 1)

    ax_row.text(-0.02, 0.5, f"{class_names[cls_id]}\nMult. Medoids",
                transform=ax_row.transAxes, ha='right', va='center', fontsize=12, fontweight='bold')

    # Iterate over grouped multiplicity bins; a medoid is drawn for bins with >= 3 jets
    for g in range(0, group_max+1):
        mask = (grouped == g) & (labels == cls_id)
        idxs = np.where(mask)[0]
        if len(idxs) > 2:
            d_sub = D[np.ix_(idxs, idxs)]
            medoid_local = np.argmin(d_sub.sum(axis=1))
            medoid_idx = idxs[medoid_local]

            # Centre the inset over bar g: bar g occupies slot g of the group_max + 1
            # equal slots, so its centre is at (g + 0.5) / (group_max + 1) in axes fractions.
            x_norm = (g + 0.5) * slot_frac
            ax_ins = inset_axes(ax_row, width="80%", height="90%",
                                bbox_to_anchor=(x_norm - 0.5 * slot_frac, 0, slot_frac, 1),
                                bbox_transform=ax_row.transAxes, loc='center')
            plot_jet_scatter(ax_ins, 0, 0, jets_X[medoid_idx], jets_w[medoid_idx], scale=0.8, max_markersize=40)
            ax_ins.set_xticks([])
            ax_ins.set_yticks([])
            ax_ins.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# ---------------------------------------------------------
# DETAILED MEDOID VISUALIZATION (Specific Mass Bin)
# ---------------------------------------------------------
# Plot a larger version of the Quark, Gluon, and Top medoids 
# for a specific mass bin (e.g., near the W/Top mass peak or a generic QCD mass).

# Ensure we have absolute jet masses (in GeV).
# Only recomputed when this cell runs without the mass-distribution cell above.
if 'abs_masses' not in locals():
    if 'masses' in locals() and 'jets_pt' in locals():
        print("Calculating absolute masses from normalized masses...")
        abs_masses = masses * jets_pt
    else:
        print("Calculating jet masses from constituents...")
        # Massless-constituent invariant mass:
        # E = pt cosh(y), p = (pt cos(phi), pt sin(phi), pt sinh(y)), m^2 = E^2 - |p|^2
        m_list = []
        for i in range(len(jets_X)):
            pt_i = jets_w[i] * jets_pt[i]
            y_i = jets_X[i][:, 0]
            phi_i = jets_X[i][:, 1]

            E = np.sum(pt_i * np.cosh(y_i))
            Px = np.sum(pt_i * np.cos(phi_i))
            Py = np.sum(pt_i * np.sin(phi_i))
            Pz = np.sum(pt_i * np.sinh(y_i))

            m2 = E**2 - Px**2 - Py**2 - Pz**2
            m_list.append(np.sqrt(max(0, m2)))  # clip round-off negatives
        abs_masses = np.array(m_list)

# Target Mass Range in GeV
# Top Quark mass is approx 173 GeV. W Boson is approx 80 GeV.
target_mass_min = 170.0
target_mass_max = 175.0

print(f"Finding representative jets (medoids) in mass range [{target_mass_min:.1f}, {target_mass_max:.1f}] GeV...")

# Find indices in this bin using absolute mass
mask_bin = (abs_masses >= target_mass_min) & (abs_masses < target_mass_max)

plt.figure(figsize=(15, 5))

plot_classes = [c for c in [1, 0, 2] if c in present_classes] # Quark, Gluon, Top

for i, cls_id in enumerate(plot_classes):
    # Filter by class
    mask = mask_bin & (labels == cls_id)
    idxs = np.where(mask)[0]

    if len(idxs) > 0:
        # Medoid: the jet minimising the summed EMD to the other in-bin jets of its class
        d_sub = D[np.ix_(idxs, idxs)]
        medoid_local = np.argmin(d_sub.sum(axis=1))
        medoid_idx = idxs[medoid_local]

        # Plot
        ax = plt.subplot(1, 3, i + 1)
        plot_jet_scatter(ax, 0, 0, jets_X[medoid_idx], jets_w[medoid_idx], scale=0.8, max_markersize=200)

        ax.set_xlim(-1.0, 1.0)
        ax.set_ylim(-1.0, 1.0)
        ax.set_title(f"{class_names[cls_id]} Medoid\nMass $\\in$ [{target_mass_min}, {target_mass_max}] GeV", fontsize=14)
        # Constituent coordinates are (rapidity y, azimuth phi), not pseudorapidity eta
        ax.set_xlabel("$y$")
        if i == 0: ax.set_ylabel("$\\phi$")
        ax.grid(True, alpha=0.3)
    else:
        print(f"No {class_names[cls_id]} jets found in this mass bin.")
        # Create empty placeholder
        ax = plt.subplot(1, 3, i + 1)
        ax.text(0.5, 0.5, "No Jets Found", ha='center', va='center')
        ax.set_title(f"{class_names[cls_id]} Medoid")
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# ---------------------------------------------------------
# CLASSIFICATION: Top vs. QCD (Quark/Gluon) using EMD
# ---------------------------------------------------------
# We use the pairwise EMD matrix to perform k-Nearest Neighbor (k-NN) classification.
# This demonstrates the power of the metric space structure for distinguishing jet classes.
# Reference: Komiske, Metodiev, Thaler, "The Metric Space of Collider Events", PRL 123 (2019)

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc

print("Running Top vs. QCD Classification using EMD-based k-NN...")

# 1. Prepare Data
# Signal: Top (Label 2)
# Background: QCD (Quark=1, Gluon=0)
is_top = (labels == 2)
is_qcd = (labels == 0) | (labels == 1)

# Filter dataset to only include these (should be all, but good to be safe)
valid_mask = is_top | is_qcd
valid_indices = np.where(valid_mask)[0]

# Create binary labels for classification: 1 for Top, 0 for QCD
y_binary = (labels[valid_indices] == 2).astype(int)
indices = np.arange(len(valid_indices))

# The stratified split, the ROC curve and predict_proba(...)[:, 1] all need both
# classes. Fail early with a clear message instead of an opaque sklearn/IndexError.
n_top = int(y_binary.sum())
n_qcd = int(len(y_binary) - n_top)
if n_top < 2 or n_qcd < 2:
    raise ValueError(
        f"Top vs. QCD k-NN needs at least 2 jets of each class; found {n_top} Top and {n_qcd} QCD."
    )

# 2. Split into Train and Test
# We split the *indices* of our valid data so we can slice the distance matrix D
idx_train, idx_test, y_train, y_test = train_test_split(
    indices, y_binary, test_size=0.3, random_state=42, stratify=y_binary
)

# Map back to global indices in D
global_idx_train = valid_indices[idx_train]
global_idx_test = valid_indices[idx_test]

print(f"Training set: {len(idx_train)} jets")
print(f"Test set:     {len(idx_test)} jets")

# 3. Prepare Distance Matrices for k-NN
# k-NN with 'precomputed' metric requires:
# Fit: (n_train, n_train) distance matrix
# Predict: (n_test, n_train) distance matrix
D_train = D[np.ix_(global_idx_train, global_idx_train)]  # train vs train
D_test = D[np.ix_(global_idx_test, global_idx_train)]    # test vs train

# 4. Train k-NN Classifier
# k is a hyperparameter (sqrt(N_train) is a common heuristic). Clamp it to the
# training-set size: sklearn rejects n_neighbors > n_samples_fit at predict time.
k = min(16, len(idx_train))
knn = KNeighborsClassifier(n_neighbors=k, metric='precomputed')
knn.fit(D_train, y_train)

# 5. Predict on Test Set
# With the default uniform weights, P(Top) is the fraction of the k nearest
# training jets that are Top jets.
y_score = knn.predict_proba(D_test)[:, 1]

# 6. Compute ROC Curve
fpr, tpr, thresholds = roc_curve(y_test, y_score)
roc_auc = auc(fpr, tpr)

# 7. Plot Results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# ROC Curve
ax1.plot(fpr, tpr, color='darkorange', lw=2, label=f'EMD k-NN (k={k}) (AUC = {roc_auc:.3f})')
ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax1.set_xlim([0.0, 1.0])
ax1.set_ylim([0.0, 1.05])
ax1.set_xlabel('False Positive Rate (QCD tagged as Top)')
ax1.set_ylabel('True Positive Rate (Top efficiency)')
ax1.set_title('ROC Curve: Top vs. QCD Classification')
ax1.legend(loc="lower right")
ax1.grid(True, alpha=0.3)

# Discriminant Distribution
ax2.hist(y_score[y_test==0], bins=20, range=(0, 1), density=True, alpha=0.5, color='blue', label='QCD (True)')
ax2.hist(y_score[y_test==1], bins=20, range=(0, 1), density=True, alpha=0.5, color='red', label='Top (True)')
ax2.set_xlabel(f'k-NN Discriminant (P(Top) with k={k})')
ax2.set_ylabel('Density')
ax2.set_title('Classifier Score Distribution')
ax2.legend(loc='upper center')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Classification AUC: {roc_auc:.4f}")

In [None]:
# ---------------------------------------------------------
# CLASSIFICATION: Top vs. QCD using Energy Flow Network (EFN)
# ---------------------------------------------------------
# We train a deep neural network (EFN) on the particle sets.
# EFNs are IRC-safe architectures designed for point clouds.
# Reference: Komiske, Metodiev, Thaler, "Energy Flow Networks: Deep Sets for Particle Physics", JHEP 01 (2019) 121
# https://arxiv.org/abs/1810.05165

# Install TensorFlow and tf_keras (required for compatibility with EnergyFlow on newer TF versions)
# We use --no-deps to avoid messing up other dependencies if possible, but here we need to ensure they are installed.
!pip install -q tensorflow tf_keras

import tensorflow as tf
import os
import importlib
import energyflow

# Set environment variable to ensure EnergyFlow uses the correct Keras backend if needed
os.environ['TF_KERAS'] = '1' 

import energyflow.archs
importlib.reload(energyflow.archs)

# Now import EnergyFlow architectures
try:
    from energyflow.archs import EFN
except ImportError:
    # If it still fails, it might be a deeper issue with the backend. 
    # We try to reload the specific submodule if possible or warn the user.
    print("Could not import EFN after reload. Trying to reload energyflow.archs.efn directly...")
    try:
        import energyflow.archs.efn
        importlib.reload(energyflow.archs.efn)
        from energyflow.archs import EFN
    except Exception as e:
        print(f"Failed to import EFN: {e}")
        print("Please RESTART THE RUNTIME (Runtime > Restart session) and run this cell again.")
        raise

from sklearn.model_selection import train_test_split

print("Running Top vs. QCD Classification using EFN...")

# 1. Prepare Data
# We need to convert our list of jets into a padded numpy array (N, M, 3)
# Features: (z, y, phi) where z is pT fraction

# Filter for Top (2) and QCD (0, 1)
is_top = (labels == 2)
is_qcd = (labels == 0) | (labels == 1)
valid_mask = is_top | is_qcd
valid_indices = np.where(valid_mask)[0]

# Find max particles in the dataset
max_len = max(len(jets_w[i]) for i in valid_indices)
print(f"Max particles per jet: {max_len}")

# Initialize arrays
# We separate weights (z) and features (y, phi) because EFN expects them as separate inputs
X_z = np.zeros((len(valid_indices), max_len))
X_p = np.zeros((len(valid_indices), max_len, 2)) # (y, phi)
Y_efn = np.zeros(len(valid_indices))

for i, idx in enumerate(valid_indices):
    # Get jet data
    x_jet = jets_X[idx] # (M, 2) -> (y, phi)
    w_jet = jets_w[idx] # (M,) -> z (normalized pt)
    
    n_particles = len(w_jet)
    
    # Fill arrays
    X_z[i, :n_particles] = w_jet
    X_p[i, :n_particles, 0] = x_jet[:, 0] # y
    X_p[i, :n_particles, 1] = x_jet[:, 1] # phi
    
    # Label: 1 for Top, 0 for QCD
    if labels[idx] == 2:
        Y_efn[i] = 1
    else:
        Y_efn[i] = 0

print(f"Data shapes: Z={X_z.shape}, P={X_p.shape}")
print(f"Labels shape: {Y_efn.shape}")

# 2. Split into Train/Val/Test
# We split all arrays simultaneously
X_z_train, X_z_test, X_p_train, X_p_test, Y_train, Y_test = train_test_split(
    X_z, X_p, Y_efn, test_size=0.2, random_state=42, stratify=Y_efn
)

X_z_train, X_z_val, X_p_train, X_p_val, Y_train, Y_val = train_test_split(
    X_z_train, X_p_train, Y_train, test_size=0.2, random_state=42, stratify=Y_train
)

print(f"Train: {len(Y_train)}, Val: {len(Y_val)}, Test: {len(Y_test)}")

# 3. Define EFN Architecture
# input_dim=2 for (y, phi) - The weight z is handled separately by the architecture
# Phi_sizes: Dense layers for per-particle mapping (Latent space dim is last layer)
# F_sizes: Dense layers for global mapping (Classifier)
efn = EFN(input_dim=2, 
          Phi_sizes=[100, 100, 128], 
          F_sizes=[100, 100, 100], 
          output_dim=1, 
          output_act='sigmoid',
          loss='binary_crossentropy',
          optimizer='adam',
          metrics=['accuracy'])

# 4. Train
print("Training EFN...")
# EFN expects a list of inputs: [X_z, X_p]
history = efn.fit([X_z_train, X_p_train], Y_train, 
                  epochs=20, 
                  batch_size=500, 
                  validation_data=([X_z_val, X_p_val], Y_val), 
                  verbose=1)

# 5. Evaluate
print("Evaluating on Test Set...")
preds = efn.predict([X_z_test, X_p_test])
fpr_efn, tpr_efn, thresholds_efn = roc_curve(Y_test, preds)
roc_auc_efn = auc(fpr_efn, tpr_efn)

# 6. Plot Results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# ROC Curve Comparison
ax1.plot(fpr_efn, tpr_efn, color='green', lw=2, label=f'EFN (AUC = {roc_auc_efn:.3f})')
# Add previous k-NN result if available
if 'fpr' in locals() and 'tpr' in locals():
    ax1.plot(fpr, tpr, color='darkorange', lw=2, linestyle='--', label=f'EMD k-NN (AUC = {roc_auc:.3f})')

ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle=':')
ax1.set_xlim([0.0, 1.0])
ax1.set_ylim([0.0, 1.05])
ax1.set_xlabel('False Positive Rate')
ax1.set_ylabel('True Positive Rate')
ax1.set_title('ROC Curve Comparison')
ax1.legend(loc="lower right")
ax1.grid(True, alpha=0.3)

# Training History
ax2.plot(history.history['loss'], label='Train Loss')
ax2.plot(history.history['val_loss'], label='Val Loss')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.set_title('EFN Training History')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"EFN Classification AUC: {roc_auc_efn:.4f}")

In [None]:
# ---------------------------------------------------------
# t-SNE: Same embedding colored by jet mass
# ---------------------------------------------------------
# This cell plots the previously computed `emb` (t-SNE embedding)
# and colors points by `abs_masses`. If `abs_masses` is missing,
# it is computed from the jet constituents (may take a moment).

# At notebook top level globals() is the cell namespace, so a single lookup suffices.
if 'emb' not in globals():
    print("t-SNE embedding 'emb' not found. Run the Full t-SNE cell first.")
else:
    if 'abs_masses' not in globals():
        print("'abs_masses' not found -- computing from constituents (this may take some time)...")
        m_list = []
        for i in range(len(jets_X)):
            # Constituent pT = z * jet scale. jets_pt[i] is presumably the jet pT -- TODO confirm.
            pt_i = jets_w[i] * jets_pt[i]
            y_i = jets_X[i][:, 0]
            phi_i = jets_X[i][:, 1]
            # Constituents are treated as massless: E = pT cosh(y), pz = pT sinh(y).
            E = np.sum(pt_i * np.cosh(y_i))
            Px = np.sum(pt_i * np.cos(phi_i))
            Py = np.sum(pt_i * np.sin(phi_i))
            Pz = np.sum(pt_i * np.sinh(y_i))
            m2 = E**2 - Px**2 - Py**2 - Pz**2
            # Clip small negative m^2 from floating-point round-off before the sqrt.
            m_list.append(np.sqrt(max(0, m2)))
        abs_masses = np.array(m_list)

    fig, ax = plt.subplots(figsize=(10, 8))
    sc = ax.scatter(emb[:, 0], emb[:, 1], c=abs_masses, cmap='plasma', s=40, alpha=0.85)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label('Jet mass [GeV]')
    ax.set_title('t-SNE embedding colored by jet mass')
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')
    ax.grid(True, alpha=0.2)
    plt.show()


In [None]:
# ---------------------------------------------------------
# ANIMATION (HD + configurable colormap)
# ---------------------------------------------------------
# Produces a higher-resolution transport-based prototype-cycle GIF and lets you
# choose the colormap via `cmap_name`. This cell is a drop-in replacement
# if you prefer HD output without modifying the earlier cells.

import base64
import os
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML, display
import matplotlib.pyplot as plt

out_dir = 'figures/confusion_matrix'
os.makedirs(out_dir, exist_ok=True)

# Configurable parameters
# NOTE(review): `N` is a generic global name; check that no later cell expects an older value.
N = 3                       # prototypes per class
frames_per_transition = 20
class_order = [1, 0, 2]     # Quark, Gluon, Top
cmap_name = 'plasma'        # change to any matplotlib colormap
gif_dpi = 200               # increase for higher resolution (e.g., 200 or 300)
figsize = (6, 6)

# Build prototypes: for each class, the N jets with the smallest mean EMD to their classmates.
prototypes = {}
for cls in class_order:
    if cls in present_classes:
        inds = np.where(labels == cls)[0]
        if len(inds) == 0:
            continue
        d_sub = D[np.ix_(inds, inds)]
        avg_dist = d_sub.mean(axis=1)
        sorted_local = np.argsort(avg_dist)
        proto_globals = [int(inds[idx]) for idx in sorted_local[:N]]
        prototypes[cls] = proto_globals

# Sequence: interleave classes rank by rank (best quark, best gluon, best top, 2nd quark, ...).
seq = []
for r in range(N):
    for cls in class_order:
        if cls in prototypes and len(prototypes[cls]) > r:
            seq.append((cls, prototypes[cls][r]))

if len(seq) < 2:
    print("Not enough prototypes to animate transport. Found:", seq)
else:
    # Precompute the optimal-transport flow for each consecutive pair (the last wraps to the first).
    # `seg_idx` (not `k`) avoids clobbering the k-NN cell's global `k`.
    flows = []
    for seg_idx in range(len(seq)):
        g1 = seq[seg_idx][1]
        g2 = seq[(seg_idx + 1) % len(seq)][1]
        X1 = np.asarray(jets_X[g1])
        w1 = np.asarray(jets_w[g1])
        X2 = np.asarray(jets_X[g2])
        w2 = np.asarray(jets_w[g2])

        # Events as rows of (weight, coord_0, coord_1) for ef.emd.emd.
        ev1 = np.column_stack((w1, X1))
        ev2 = np.column_stack((w2, X2))

        try:
            _, gamma = ef.emd.emd(ev1, ev2, return_flow=True)
            gamma = np.asarray(gamma)
            # Trim the flow matrix to the real particles of each jet.
            n_rows = min(gamma.shape[0], X1.shape[0])
            n_cols = min(gamma.shape[1], X2.shape[0])
            gamma = gamma[:n_rows, :n_cols]
        except Exception as e:
            print(f"ef.emd.emd failed for pair {g1}->{g2}: {e}. Using None for gamma.")
            gamma = None

        flows.append({'X1': X1, 'w1': w1, 'X2': X2, 'w2': w2, 'gamma': gamma, 'pair': (g1, g2)})

    total_frames = len(flows) * frames_per_transition
    fig, ax = plt.subplots(figsize=figsize)

    s_scale = 2500  # marker area per unit of transported weight

    cmap = plt.get_cmap(cmap_name)

    def update(frame):
        """Draw frame `frame`: particles interpolated between two prototypes at time t in [0, 1]."""
        seg = frame // frames_per_transition
        f_in_seg = frame % frames_per_transition
        t = f_in_seg / float(max(1, frames_per_transition - 1))

        data = flows[seg]
        X1, X2, gamma, w1 = data['X1'], data['X2'], data['gamma'], data['w1']

        ax.clear()
        ax.axis('off')
        ax.set_xlim(-0.9, 0.9)
        ax.set_ylim(-0.9, 0.9)
        ax.set_aspect('equal')

        if gamma is None:
            # fallback: linear interpolation, color by source weight
            n = max(len(X1), len(X2))
            X1p = np.vstack([X1, np.tile(X1[-1], (n - len(X1), 1))]) if len(X1) < n else X1[:n]
            X2p = np.vstack([X2, np.tile(X2[-1], (n - len(X2), 1))]) if len(X2) < n else X2[:n]
            pos = (1 - t) * X1p + t * X2p
            wsrc = np.concatenate([w1, np.zeros(max(0, n - len(w1)))])
            if wsrc.max() > 0:
                colors = wsrc / (wsrc.max() + 1e-12)
            else:
                colors = np.ones_like(wsrc)
            ax.scatter(pos[:, 0], pos[:, 1], s=30, c=colors, cmap=cmap, alpha=0.9, edgecolors='none')
        else:
            # One marker per non-negligible flow entry (i -> j), moving from X1[i] to X2[j]
            # with size and color proportional to the transported amount.
            thresh = max(gamma.max() * 1e-8, 1e-12)
            src, dst = np.nonzero(gamma > thresh)
            if src.size > 0:
                q = gamma[src, dst]
                pos = (1 - t) * X1[src] + t * X2[dst]
                sizes = np.maximum(1.0, q * s_scale)
                colors = q / (gamma.max() + 1e-12)
                ax.scatter(pos[:, 0], pos[:, 1], s=sizes, c=colors, cmap=cmap, alpha=0.95, edgecolors='none')

        return []

    anim = FuncAnimation(fig, update, frames=total_frames, interval=80, blit=False)

    gif_fname = os.path.join(out_dir, f'prototype_transport_cycle_top{N}_q_g_t_{cmap_name}_hd.gif')
    print(f"Saving HD transport-based prototype-cycle GIF to {gif_fname} (dpi={gif_dpi}) ...")
    anim.save(gif_fname, writer=PillowWriter(fps=12), dpi=gif_dpi)
    plt.close(fig)

    print(f"Saved HD transport-based prototype-cycle GIF: {gif_fname}")
    # Embed the GIF as a data URI: a relative <img src> path is not served to Colab's
    # sandboxed output frame. Note this stores the full GIF inside the notebook output.
    try:
        with open(gif_fname, 'rb') as fh:
            gif_b64 = base64.b64encode(fh.read()).decode('ascii')
        display(HTML(f"<img src=\"data:image/gif;base64,{gif_b64}\" style=\"max-width:480px;\">"))
    except OSError as e:
        print(f"Could not display {gif_fname}: {e}")


In [None]:
# ---------------------------------------------------------
# FRACTAL CORRELATION DIMENSION
# ---------------------------------------------------------
# We compute the correlation dimension of the dataset manifold
# by analyzing the distribution of pairwise distances:
#   C(r)   = number of jet pairs with EMD < r   (correlation sum)
#   dim(r) = d log C(r) / d log r

print("Computing Fractal Correlation Dimensions...")

# Log-spaced bin edges for the histogram of distances (EMDs)
# Range: 10^-2 to 10^0 (0.01 to 1.0)
bins = 10**np.linspace(-2, 0, 60)
reg = 10**-30  # Regularization to avoid log(0)
midbins = (bins[:-1] + bins[1:])/2
# With log-spaced edges, log(midbins[j+1]/midbins[j]) equals the log width of every bin.
dmidbins = np.log(midbins[1:]) - np.log(midbins[:-1]) + reg
midbins2 = (midbins[:-1] + midbins[1:])/2  # kept for any later cell that uses it

# cdf_counts[j] = C(bins[j+1]), so dims[j] is the slope between bins[j+1] and bins[j+2].
# That interval is centred on midbins[j+1], not on midbins2[j] (half a bin lower).
r_centers = midbins[1:]

plt.figure(figsize=(8, 6))

# Iterate over each class to compute its dimension
for i in present_classes:
    # Get indices for this class
    inds = np.where(labels == i)[0]
    if len(inds) < 2:
        print(f"Skipping {class_names[i]}: need at least 2 jets for pairwise distances.")
        continue

    # D is symmetric, so the strict upper triangle (k=1 drops the zero diagonal)
    # lists every unordered pair exactly once.
    d_sub = D[np.ix_(inds, inds)]
    emd_vals = d_sub[np.triu_indices_from(d_sub, k=1)]

    # Histogram of pair distances, accumulated into the correlation sum.
    # Pairs closer than the first edge also satisfy d < r for every r on the grid;
    # without them C(r) is undercounted and the small-r slope is biased upward.
    hist_counts, _ = np.histogram(emd_vals, bins=bins)
    n_below = np.count_nonzero(emd_vals < bins[0])
    cdf_counts = n_below + np.cumsum(hist_counts)

    # Local slope of log(C(r)) vs log(r)
    dims = (np.log(cdf_counts[1:] + reg) - np.log(cdf_counts[:-1] + reg)) / dmidbins

    plt.plot(r_centers, dims, '-', label=f'{class_names[i]} Jets', alpha=0.8, linewidth=2)

# Styling
plt.xscale('log')
plt.xlabel('Energy Scale (EMD)')
plt.ylabel('Correlation Dimension')
plt.xlim(0.02, 1.0)
plt.ylim(0, 10)
plt.legend(loc='best', frameon=False)
plt.title('Fractal Correlation Dimension vs. Scale')
plt.grid(True, which="both", ls="-", alpha=0.2)

plt.show()