In [9]:
# ---------------------------------------------------------
# GOOGLE COLAB SETUP
# ---------------------------------------------------------
import sys
import os
import subprocess
import re

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Detected Google Colab environment.")
    
    # ---------------------------------------------------------
    # 1. Smart Installation Check
    # ---------------------------------------------------------
    # We check the installed NumPy version via pip *before* deciding to install.
    # This prevents the cell from hanging on re-runs after a restart.
    
    current_numpy_version = "0.0.0"
    try:
        # Run 'pip show numpy' to get the version on disk
        res = subprocess.run([sys.executable, "-m", "pip", "show", "numpy"], capture_output=True, text=True)
        m = re.search(r"Version:\s*(\d+\.\d+\.\d+)", res.stdout)
        if m:
            current_numpy_version = m.group(1)
    except Exception:
        pass

    print(f"Current NumPy version on disk: {current_numpy_version}")
    
    # Determine if we need to install/downgrade
    needs_install = False
    
    # Condition 1: NumPy must be 1.x
    if not current_numpy_version.startswith("1."):
        print("NumPy 2.x (or unknown) detected. Downgrade required.")
        needs_install = True
    else:
        # Condition 2: Check if pyshaper is installed
        res_pkg = subprocess.run([sys.executable, "-m", "pip", "show", "pyshaper"], capture_output=True, text=True)
        if "Name: pyshaper" not in res_pkg.stdout:
            print("pyshaper not found. Installation required.")
            needs_install = True

    if needs_install:
        print("Installing required packages (this may take a minute)...")
        # AGGRESSIVE COMPATIBILITY FIX:
        # Force reinstall to ensure we get numpy<2.0.0 and compatible scipy/sklearn
        !pip install -q -U --force-reinstall "numpy<2.0.0" "scipy<1.13.0" "scikit-learn<1.5.0" energyflow pot uproot awkward vector pyshaper
        
        print("\n" + "="*80)
        print("INSTALLATION COMPLETE.")
        print("CRITICAL: You MUST restart the Colab runtime now.")
        print("1. Go to menu: Runtime > Restart session")
        print("2. Run this cell again (it will skip installation next time)")
        print("="*80 + "\n")
    else:
        print("Environment appears correct (NumPy 1.x installed). Skipping installation.")

    # ---------------------------------------------------------
    # 2. Download Dataset
    # ---------------------------------------------------------
    folder_id = '13Rs54cAg-MyUocZMLRRozjeGcRdfjDJ5'
    output_folder = 'jetclass'
    
    if not os.path.exists(output_folder):
        print(f"Downloading dataset to {output_folder}/...")
        !pip install -q -U --no-cache-dir gdown --pre
        import gdown
        gdown.download_folder(id=folder_id, output=output_folder, quiet=False)
    else:
        print(f"Folder '{output_folder}' already exists. Skipping download.")
        
else:
    print("Not running in Google Colab. Assuming local environment.")

Not running in Google Colab. Assuming local environment.


In [10]:
# Setup and Imports
import numpy as np
import matplotlib.pyplot as plt

# CRITICAL: Check NumPy version before importing other libraries
print(f"NumPy version: {np.__version__}")
if np.__version__ >= '2.0.0':
    msg = (
        f"Detected NumPy {np.__version__}, but this notebook requires NumPy < 2.0.0.\n"
        "Please run the 'GOOGLE COLAB SETUP' cell above to install the correct versions,\n"
        "then RESTART the runtime (Runtime > Restart session) and run this cell again."
    )
    raise RuntimeError(msg)

import ot
import uproot
import energyflow as ef

print("Libraries imported successfully.")
print(f"EnergyFlow version: {ef.__version__}")
print(f"Uproot version: {uproot.__version__}")

NumPy version: 1.26.4
Libraries imported successfully.
EnergyFlow version: 1.4.0
Uproot version: 5.6.9


## 1. Data Loading: JetClass Dataset

The [JetClass dataset](https://zenodo.org/records/6619768) is a large-scale dataset for jet tagging. We will load the data from ROOT files located in the `inputs/` directory.

*   **QCD Jets:** Background jets, typically 1-prong.
*   **W Jets:** Signal jets, 2-prong structure.
*   **Top Jets:** Signal jets, 3-prong structure.

We will use `uproot` to load the particle kinematics and convert them into the $(\eta, \phi)$ plane for Optimal Transport analysis.

In [11]:
# Visualization & Utility Functions

def plot_jet(X, w, title="Jet"):
    """
    Plot a single jet as a scatter plot in the eta-phi plane.
    """
    plt.scatter(X[:, 0], X[:, 1], s=w*1000, alpha=0.6, c=w, cmap='viridis')
    plt.xlim(-1.2, 1.2)
    plt.ylim(-1.2, 1.2)
    plt.xlabel("$\\eta$")
    plt.ylabel("$\\phi$")
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.colorbar(label="$p_T$ fraction")

def jet_to_image(X, w, grid_size=28, r=1.2):
    """
    Rasterize a jet (point cloud) into a grid image.
    """
    img, _, _ = np.histogram2d(
        X[:, 0], X[:, 1], 
        bins=grid_size, 
        range=[[-r, r], [-r, r]], 
        weights=w
    )
    return img.T  # Transpose to match image coordinates

print("Visualization utilities defined.")

Visualization utilities defined.


In [12]:
# ---------------------------------------------------------
# EXPERIMENTS: Multi-class jets (JetClass inputs), OT distances via Sinkhorn or EnergyFlow
# ---------------------------------------------------------
import os
import glob
import multiprocessing
import numpy as np
from joblib import Parallel, delayed
from sklearn.manifold import TSNE

# Try to import GPU libraries
try:
    import torch
    import geomloss
    USE_GPU = torch.cuda.is_available()
    if USE_GPU:
        print(f"GPU detected: {torch.cuda.get_device_name(0)}")
    else:
        print("No GPU detected. Using CPU.")
except ImportError:
    USE_GPU = False
    print("Torch/GeomLoss not found. Using CPU.")

# Parameters: 10 JetClass categories
classes = {
    0: { 'name': 'W->qq',        'file_pattern': 'WToQQ' },
    1: { 'name': 'Z->qq',        'file_pattern': 'ZToQQ' },    
    2: { 'name': 'TTbar->bqq',   'file_pattern': 'TTBar_' },
    # 3: { 'name': 'Z->nunu',      'file_pattern': 'ZJetsToNuNu' },
    # 4: { 'name': 'TTbar->blnu',  'file_pattern': 'TTBarLep' },
    # 5: { 'name': 'H->bb',        'file_pattern': 'HToBB' },
    # 6: { 'name': 'H->cc',        'file_pattern': 'HToCC' },
    # 7: { 'name': 'H->gg',        'file_pattern': 'HToGG' },
    # 8: { 'name': 'H->WW(2q1l)',  'file_pattern': 'HToWW2Q1L' },
    # 9: { 'name': 'H->WW(4q)',    'file_pattern': 'HToWW4Q' }
}

def get_label_from_filename(fname):
    for label, info in classes.items():
        if info['file_pattern'] in fname:
            return label
    return -1

# Loader: load jets from jetclass/ (ROOT files)
def load_jets_from_inputs(max_jets_per_file=200, max_particles=128):
    # Look in jetclass/ directory
    files = glob.glob('jetclass/*.root')
    jets_X = []
    jets_w = []
    labels = []
    
    if not files:
        print("WARNING: No .root files found in jetclass/ directory!")
        return [], [], []

    print(f"Found {len(files)} ROOT files. Loading...")

    for fpath in files:
        fname = os.path.basename(fpath)
        lab = get_label_from_filename(fname)
        
        if lab == -1:
            # print(f"Skipping file with unknown class: {fname}")
            continue
            
        print(f"Loading {classes[lab]['name']} from {fname}...")

        try:
            with uproot.open(fpath) as root:
                # Search for any tree and examine its branches
                for key in root.keys():
                    try:
                        obj = root[key]
                        if hasattr(obj, 'arrays'):
                            arrs = obj.arrays(library='np')
                            keys = arrs.keys()
                            
                            # Check for px, py, pz
                            px_key = next((k for k in keys if k.endswith('px')), None)
                            py_key = next((k for k in keys if k.endswith('py')), None)
                            pz_key = next((k for k in keys if k.endswith('pz')), None)
                            
                            if px_key and py_key and pz_key:
                                px = arrs[px_key]
                                py = arrs[py_key]
                                pz = arrs[pz_key]
                                n_jets = min(len(px), max_jets_per_file)
                                for i in range(n_jets):
                                    pxi = np.array(px[i])[:max_particles]
                                    pyi = np.array(py[i])[:max_particles]
                                    pzi = np.array(pz[i])[:max_particles]
                                    
                                    # Filter empty jets or jets with very low pT
                                    if len(pxi) < 2: continue
                                    
                                    pti = np.sqrt(pxi**2 + pyi**2)
                                    if pti.sum() < 1e-3: continue

                                    p_mag = np.sqrt(pxi**2 + pyi**2 + pzi**2)
                                    denom = p_mag - pzi + 1e-12
                                    eta = 0.5 * np.log((p_mag + pzi) / denom)
                                    phi = np.arctan2(pyi, pxi)
                                    
                                    # Center the jet
                                    eta_avg = np.average(eta, weights=pti)
                                    phi_avg = np.average(phi, weights=pti)
                                    eta -= eta_avg
                                    phi -= phi_avg
                                    phi = (phi + np.pi) % (2 * np.pi) - np.pi

                                    X = np.stack([eta, phi], axis=1)
                                    w = pti / pti.sum()
                                    
                                    jets_X.append(X)
                                    jets_w.append(w)
                                    labels.append(lab)
                                break
                            
                            # Check for pt, eta, phi
                            pt_key = next((k for k in keys if k.endswith('pt')), None)
                            eta_key = next((k for k in keys if k.endswith('eta')), None)
                            phi_key = next((k for k in keys if k.endswith('phi')), None)

                            if pt_key and eta_key and phi_key:
                                pt = arrs[pt_key]
                                eta = arrs[eta_key]
                                phi = arrs[phi_key]
                                n_jets = min(len(pt), max_jets_per_file)
                                for i in range(n_jets):
                                    pti = np.array(pt[i])[:max_particles]
                                    etai = np.array(eta[i])[:max_particles]
                                    phii = np.array(phi[i])[:max_particles]
                                    
                                    if len(pti) < 2: continue
                                    if pti.sum() < 1e-3: continue

                                    # Center
                                    eta_avg = np.average(etai, weights=pti)
                                    phi_avg = np.average(phii, weights=pti)
                                    etai -= eta_avg
                                    phii -= phi_avg
                                    phii = (phii + np.pi) % (2 * np.pi) - np.pi

                                    X = np.stack([etai, phii], axis=1)
                                    w = pti / pti.sum()
                                    
                                    jets_X.append(X)
                                    jets_w.append(w)
                                    labels.append(lab)
                                break
                    except Exception as e:
                        continue
        except Exception as e:
            continue
            
    return jets_X, jets_w, np.array(labels)

# Load Data
jets_X, jets_w, labels = load_jets_from_inputs(max_jets_per_file=1000)
N = len(jets_X)
print(f"Total jets loaded: {N}")

if N == 0:
    raise ValueError("No jets loaded! Please check jetclass/ directory.")

# OT computation: Toggle between Sinkhorn (approximate) and EnergyFlow (exact)
USE_SINKHORN = True  # Default to Sinkhorn for speed
USE_GPU_IF_AVAILABLE = True

def compute_pairwise_matrix_gpu(jets_X, jets_w, max_particles=128):
    """
    Compute pairwise Sinkhorn distances using GeomLoss on GPU.
    """
    print("Preparing data for GPU...")
    N = len(jets_X)
    
    # Pad jets to fixed size
    X_pad = np.zeros((N, max_particles, 2), dtype=np.float32)
    w_pad = np.zeros((N, max_particles), dtype=np.float32)
    
    for i in range(N):
        n_p = min(len(jets_X[i]), max_particles)
        X_pad[i, :n_p] = jets_X[i][:n_p]
        w_pad[i, :n_p] = jets_w[i][:n_p]
        # Normalize weights again just in case
        w_sum = w_pad[i].sum()
        if w_sum > 0:
            w_pad[i] /= w_sum

    # Convert to Torch tensors
    X_dev = torch.tensor(X_pad).cuda()
    w_dev = torch.tensor(w_pad).cuda()
    
    # Define GeomLoss Sinkhorn
    # blur = 0.05 corresponds to the regularization parameter
    loss = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05)
    
    print("Computing pairwise matrix on GPU (this may take a moment)...")
    
    D_matrix = np.zeros((N, N))
    
    # Batch size for rows. 
    # We compute distances from 'batch_size' jets to ALL 'N' jets at once.
    # Memory usage is roughly (batch_size * N) * (max_particles^2) * 4 bytes.
    # For N=3000, P=128, batch_size=10 -> ~30k pairs -> ~2GB memory for kernel. Safe for T4.
    batch_size = 10 
    
    for i in range(0, N, batch_size):
        end = min(i + batch_size, N)
        B_curr = end - i
        
        # Batch of rows: (B, P, D)
        X_batch = X_dev[i:end]
        w_batch = w_dev[i:end]
        
        # We need to compute B_curr * N pairwise distances.
        # We construct flat tensors of size (B_curr * N, P, D)
        
        # Expand Source (i): Repeat each row N times
        # (B, P, D) -> (B * N, P, D)
        Xi_flat = X_batch.repeat_interleave(N, dim=0)
        wi_flat = w_batch.repeat_interleave(N, dim=0)
        
        # Expand Target (j): Tile the full set of N columns B times
        # (N, P, D) -> (B * N, P, D)
        Xj_flat = X_dev.repeat(B_curr, 1, 1)
        wj_flat = w_dev.repeat(B_curr, 1) # Fixed: repeat 2D tensor only once for dim 1
        
        # Compute loss
        # Output: (B * N, )
        # GeomLoss expects (Batch, P, D) inputs
        L_flat = loss(wi_flat, Xi_flat, wj_flat, Xj_flat)
        
        # Reshape to (B, N)
        L_mat = L_flat.view(B_curr, N)
        
        D_matrix[i:end, :] = L_mat.detach().cpu().numpy()
        
        if i % 100 == 0:
            print(f"Computed rows {i} to {end} of {N}...")
            torch.cuda.empty_cache()
            
    return D_matrix

def compute_ot_pair_cpu(i, j):
    Xi, wi = jets_X[i], jets_w[i]
    Xj, wj = jets_X[j], jets_w[j]
    
    if USE_SINKHORN:
        M = ot.dist(Xi, Xj, metric='euclidean')
        reg = 0.05 
        val = ot.sinkhorn2(wi, wj, M, reg)
        if isinstance(val, np.ndarray): val = val.item()
        return i, j, float(val)
    else:
        ev1 = np.column_stack((wi, Xi[:, 0], Xi[:, 1]))
        ev2 = np.column_stack((wj, Xj[:, 0], Xj[:, 1]))
        val = ef.emd.emd(ev1, ev2, R=10.0)
        return i, j, float(val)

# Main Computation Logic
if USE_SINKHORN and USE_GPU and USE_GPU_IF_AVAILABLE:
    print("Using GPU-accelerated Sinkhorn (GeomLoss)...")
    D = compute_pairwise_matrix_gpu(jets_X, jets_w)
    method_name = "Sinkhorn (GPU)"
else:
    method_name = "Sinkhorn (CPU)" if USE_SINKHORN else "Exact EMD (CPU)"
    print(f"Computing pairwise {method_name} distances using CPU parallel processing...")
    
    n_cores = min(8, multiprocessing.cpu_count())
    pairs = [(i, j) for i in range(N) for j in range(i+1, N)]
    print(f"Total pairs to compute: {len(pairs)}")

    results = Parallel(n_jobs=n_cores, verbose=5)(
        delayed(compute_ot_pair_cpu)(i, j) for i, j in pairs
    )

    D = np.zeros((N, N))
    for i, j, val in results:
        D[i, j] = val
        D[j, i] = val
    np.fill_diagonal(D, 0.0)

print("Pairwise distance matrix computed.")

# t-SNE embedding using precomputed distances
tsne = TSNE(n_components=2, metric='precomputed', init='random', random_state=42, perplexity=min(30, N-1))
emb = tsne.fit_transform(D)

plt.figure(figsize=(9, 7))
scatter = plt.scatter(emb[:, 0], emb[:, 1], c=labels, cmap='tab10', s=40, alpha=0.8)
plt.title(f't-SNE embedding of jets ({method_name})')
# Create legend handles manually since we might not have all classes
handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=plt.cm.tab10(i), label=classes[i]['name'], markersize=10) for i in classes if i in np.unique(labels)]
plt.legend(handles=handles)
plt.grid(True, alpha=0.2)
plt.show()

# Scatter-thumbnail helper: draw a small jet (particle scatter) centered at (x0, y0)
def plot_jet_scatter(ax, x0, y0, X, w, scale=0.25, max_markersize=60, cmap='viridis', alpha=0.8):
    xs = x0 + X[:, 0] * scale
    ys = y0 + X[:, 1] * scale
    s = (w / (w.max() + 1e-12)) * max_markersize

No GPU detected. Using CPU.
Found 10 ROOT files. Loading...
Loading TTbar->bqq from TTBar_000.root...
Loading W->qq from WToQQ_000.root...
Loading Z->qq from ZToQQ_000.root...
Total jets loaded: 3000
Computing pairwise Sinkhorn (CPU) distances using CPU parallel processing...
Total pairs to compute: 4498500


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 tasks      | elapsed:    7.2s
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed:   10.0s
[Parallel(n_jobs=8)]: Done 146 tasks      | elapsed:   15.2s


KeyboardInterrupt: 