In [1]:
%run input/Format.ipynb
import ROOT as root
from array import array
root.gErrorIgnoreLevel = root.kFatal
%jsroot on
import numpy as np

  validate(nb)


Welcome to JupyROOT 6.30/06


In [2]:
file_path = "input/"
# Input residual files; only the last assignment was ever used. Earlier inputs are
# kept here as alternatives (uncomment one to switch).
# file_names = ["0version_Box_TPC_Au_Au_ZeroField_1mrad_aligned_10evt_9nSkip_75570-0_resid.root"]
# file_names = ['Au_Au_seeds_37thevt_66522-0_resid.root']
file_names = ['Au_Au_seeds__8evts_7skip_75405-0_resid.root']

In [3]:
# Global mode: use radial triplets everywhere
TRIPLET_MODE = "radial"


def assert_radial():
    """Reset the module-level TRIPLET_MODE to "radial", printing a warning if it had changed."""
    global TRIPLET_MODE
    if TRIPLET_MODE == "radial":
        return
    print("Warning: Forcing TRIPLET_MODE='radial'")
    TRIPLET_MODE = "radial"


print("TRIPLET_MODE:", TRIPLET_MODE)


TRIPLET_MODE: radial


In [4]:
hists_read, hists_sim = [], []

# Every opened TFile stays referenced here so the TTrees fetched from it remain
# accessible in later cells.
open_files = []

# Key used in `trees` -> name of the TTree inside each input file.
_TREE_NAMES = {
    "cluster": "clustertree",
    "residual": "residualtree",
    "hit": "hittree",
    "vertex": "vertextree",
}
# One list per tree kind, with one slot per input file (None when the file failed to open).
trees = {kind: [] for kind in _TREE_NAMES}

for iFile, file_name in enumerate(file_names):
    fpath = file_path + file_name
    tfile = root.TFile.Open(fpath, "READ")
    if not tfile or tfile.IsZombie():
        print(f"Failed to open {fpath}")
        # keep the per-file lists aligned with file_names
        for per_file in trees.values():
            per_file.append(None)
        continue

    open_files.append(tfile)
    for kind, tree_name in _TREE_NAMES.items():
        trees[kind].append(tfile.Get(tree_name))


def _first_tree(kind):
    """Tree of kind `kind` from the first input file, or None if no file was listed."""
    per_file = trees[kind]
    return per_file[0] if per_file else None


# Handy shorthand to the first file's trees
cluster_tree = _first_tree("cluster")
residual_tree = _first_tree("residual")
hit_tree = _first_tree("hit")
vertex_tree = _first_tree("vertex")

print(f"Loaded files: {len(open_files)}")
print("cluster_tree entries:", cluster_tree.GetEntries() if cluster_tree else 0)
print("residual_tree entries:", residual_tree.GetEntries() if residual_tree else 0)

Loaded files: 1
cluster_tree entries: 195129
residual_tree entries: 4616


In [5]:
# Quick check: print the leading branch names of each tree to confirm they are readable.
def _leading_branch_names(tree, limit=10):
    """Names of the first `limit` branches (or fewer) of a TTree, in branch-list order."""
    branch_list = tree.GetListOfBranches()
    count = min(limit, branch_list.GetEntries())
    return [branch_list.At(k).GetName() for k in range(count)]


if cluster_tree:
    cluster_branches = _leading_branch_names(cluster_tree)
    print("cluster_tree branches (first 10):", cluster_branches)
if residual_tree:
    residual_branches = _leading_branch_names(residual_tree)
    print("residual_tree branches (first 10):", residual_branches)

cluster_tree branches (first 10): ['run', 'segment', 'event', 'gl1bco', 'lx', 'lz', 'gx', 'gy', 'gz', 'phi']
residual_tree branches (first 10): ['run', 'segment', 'event', 'mbdcharge', 'mbdzvtx', 'firedTriggers', 'gl1BunchCrossing', 'trackid', 'tpcid', 'silid']


In [6]:

# Read every cluster of cluster_tree and store its global position in cylindrical
# coordinates: cluster_points_cyl[k] = (r, phi, z) as float32, where k is the TTree entry.

cluster_points_cyl = None
cluster_entry_index = None  # maps point index -> TTree entry index

if not cluster_tree:
    print("No cluster_tree available; spatial index not built.")
else:
    n_entries = int(cluster_tree.GetEntries())
    rs, phis, zs = [], [], []
    for i in range(n_entries):
        cluster_tree.GetEntry(i)
        gx, gy = cluster_tree.gx, cluster_tree.gy  # global x, y branches
        r = np.sqrt(gx**2 + gy**2)
        phi = np.arctan2(gy, gx)
        rs.append(float(r))
        phis.append(float(phi))
        zs.append(float(cluster_tree.gz))
    cluster_points_cyl = np.column_stack([rs, phis, zs]).astype(np.float32)
    # one point per entry, in entry order
    cluster_entry_index = np.arange(n_entries, dtype=np.int64)
    print(f"cluster_points_cyl shape: {cluster_points_cyl.shape}")

cluster_points_cyl shape: (195129, 3)


In [7]:
# Build a KD-tree directly over the (r, phi, z) cluster coordinates.
# NOTE(review): plain Euclidean distance here mixes cm (r, z) with radians (phi) and
# does not wrap phi at +-pi. The chain finder's neighbour queries
# (neighbors_radius_cyl) use the Cartesian index spatial_index_xyz, not this one.

spatial_index_cyl = None
_index_backend_cyl = None

if cluster_points_cyl is not None and len(cluster_points_cyl):
    # Only a missing library triggers the fallback; errors while building the tree
    # (bad shape, bad values) propagate instead of being masked as "backend: none".
    try:
        from scipy.spatial import cKDTree as _KDTree
    except ImportError:
        _KDTree = None

    if _KDTree is not None:
        spatial_index_cyl = _KDTree(cluster_points_cyl)
        _index_backend_cyl = "scipy.cKDTree"
    else:
        try:
            from sklearn.neighbors import KDTree as _SKKDTree
        except ImportError:
            _index_backend_cyl = "none"
        else:
            spatial_index_cyl = _SKKDTree(cluster_points_cyl)
            _index_backend_cyl = "sklearn.KDTree"

print("Cylindrical spatial index backend:", _index_backend_cyl)


Cylindrical spatial index backend: scipy.cKDTree


In [8]:
# NOTE(review): this cell repeats the previous one and rebuilds the same
# (r, phi, z) KD-tree; it can be deleted.
# Plain Euclidean distance here mixes cm (r, z) with radians (phi) and does not
# wrap phi; neighbour queries for chain building use spatial_index_xyz instead.

spatial_index_cyl = None
_index_backend_cyl = None

if cluster_points_cyl is not None and len(cluster_points_cyl):
    # Fall back only when a library is missing; tree-construction errors propagate.
    try:
        from scipy.spatial import cKDTree as _KDTree
    except ImportError:
        _KDTree = None

    if _KDTree is not None:
        spatial_index_cyl = _KDTree(cluster_points_cyl)
        _index_backend_cyl = "scipy.cKDTree"
    else:
        try:
            from sklearn.neighbors import KDTree as _SKKDTree
        except ImportError:
            _index_backend_cyl = "none"
        else:
            spatial_index_cyl = _SKKDTree(cluster_points_cyl)
            _index_backend_cyl = "sklearn.KDTree"

print("Cylindrical spatial index backend:", _index_backend_cyl)


Cylindrical spatial index backend: scipy.cKDTree


In [9]:
# Per-hit cylindrical coordinates as float64 arrays, taken from the float32
# cluster_points_cyl (shape [N, 3], columns r, phi, z).
assert cluster_points_cyl is not None
assert cluster_points_cyl.shape[1] == 3

r_cyl, phi_raw, z_cyl = (cluster_points_cyl[:, col].astype(float) for col in range(3))

# Re-wrap phi via atan2(sin, cos) so every value lies in [-pi, pi].
phi_cyl = np.arctan2(np.sin(phi_raw), np.cos(phi_raw))

# Polar angle relative to the beam (z) axis.
theta_cyl = np.arctan2(r_cyl, z_cyl)

print("N hits (cyl):", len(cluster_points_cyl))
for _label, _values in (("r   range:", r_cyl),
                        ("phi range:", phi_cyl),
                        ("theta range:", theta_cyl)):
    print(_label, _values.min(), _values.max())


N hits (cyl): 195129
r   range: 1.8486348390579224 84.67854309082031
phi range: -3.141537666320801 3.1415791511535645
theta range: 0.14315392004395688 3.005435503532732


In [10]:

def angle_diff(a, b):
    """
    Signed smallest angular difference ``a - b`` in radians.

    The result is wrapped into [-pi, pi] (nominally [-pi, pi)). The function works
    element-wise on numpy arrays as well as on scalars.
    """
    two_pi = 2 * np.pi
    return ((a - b) + np.pi) % two_pi - np.pi


In [11]:
# Convert the stored (r, phi, z) hits (columns 0, 1, 2 of cluster_points_cyl)
# to Cartesian (x, y, z).
r, phi, z = (cluster_points_cyl[:, k] for k in range(3))
x, y = r * np.cos(phi), r * np.sin(phi)

cluster_points_xyz = np.column_stack([x, y, z])


In [12]:
from scipy.spatial import cKDTree

# Backend tag read by neighbors_radius_cyl to select the query API.
_index_backend_xyz: str = "scipy.cKDTree"
# KD-tree over Cartesian hit positions, used for 3D radius queries.
spatial_index_xyz = cKDTree(cluster_points_xyz)


In [13]:
def neighbors_radius_cyl(i: int, radius: float):
    """
    Indices of all hits within ``radius`` of hit ``i`` in 3D Cartesian space.

    Despite the "_cyl" suffix, distances are measured in (x, y, z) using
    ``spatial_index_xyz`` built on ``cluster_points_xyz``; ``radius`` uses the same
    length unit as the coordinates (e.g. cm). Hit ``i`` itself is included.
    The query method is chosen from the ``_index_backend_xyz`` tag; with an
    unknown tag a brute-force distance scan is used.
    """
    assert spatial_index_xyz is not None, "spatial_index_xyz is None"

    backend = _index_backend_xyz
    if backend == "scipy.cKDTree":
        return spatial_index_xyz.query_ball_point(cluster_points_xyz[i], r=radius)
    if backend == "sklearn.KDTree":
        return spatial_index_xyz.query_radius(cluster_points_xyz[i:i+1], r=radius)[0].tolist()

    # brute force: compare squared distances to avoid the sqrt
    squared = np.sum((cluster_points_xyz - cluster_points_xyz[i])**2, axis=1)
    return np.where(squared <= radius * radius)[0].tolist()


In [14]:
# Make sure things exist: the chain grower below reads these notebook globals.
assert 'neighbors_radius_cyl' in globals(), "neighbors_radius_cyl not defined."
assert all(name in globals() for name in ('r_cyl', 'phi_cyl', 'theta_cyl')), \
    "r_cyl, phi_cyl, theta_cyl must be defined from cluster_points_cyl."

def grow_chain_from_seed(
    seed_idx,
    used_mask,
    search_radius=5.0,
    min_step_dr=0.2,
    max_step_dr=6.0,
    max_dphi_step=0.10,
    max_dtheta_step=0.10,
    max_delta_r=2.0,
    max_delta_dphi=0.03,
    max_delta_dtheta=0.03,
    max_chain_hits=260,
):
    """
    Grow one chain greedily from ``seed_idx``, stepping inward in r_cyl.

    At each step the candidates are the unused hits returned by
    ``neighbors_radius_cyl(current, search_radius)``, which is a 3D Cartesian radius
    query. A candidate is considered only if all of the following hold:

    * it moves inward: ``-max_step_dr <= dr < -min_step_dr``, where ``dr = r_j - r_current``;
    * ``|dphi| <= max_dphi_step`` and ``|dtheta| <= max_dtheta_step``;
    * after the first accepted step, dphi, dtheta and dr each differ from the previous
      step's value by at most ``max_delta_dphi``, ``max_delta_dtheta`` and ``max_delta_r``.

    The surviving candidate with the lowest smoothness score is appended; on ties the
    first one encountered wins. Growth stops when no candidate survives or when the
    chain reaches ``max_chain_hits`` hits.

    Parameters
    ----------
    seed_idx : int
        Starting hit, an index into the global r_cyl / phi_cyl / theta_cyl arrays.
    used_mask : numpy.ndarray of bool
        Modified in place: the seed and every appended hit are set to True.
    search_radius : float
        Neighbour-query radius, in the same length unit as the hit coordinates.
    max_chain_hits : int
        Maximum chain length, seed included.

    Returns
    -------
    list
        Hit indices in chain order (outer -> inner), starting with ``seed_idx``.
    """
    chain = [seed_idx]
    used_mask[seed_idx] = True

    prev_dphi = None
    prev_dtheta = None
    prev_dr = None

    current_idx = seed_idx

    for _ in range(max_chain_hits - 1):
        rc     = r_cyl[current_idx]
        phic   = phi_cyl[current_idx]
        thetac = theta_cyl[current_idx]

        # Unused hits within search_radius of the current hit. The distance is 3D
        # Cartesian, because neighbors_radius_cyl queries the x,y,z KD-tree.
        neigh = neighbors_radius_cyl(current_idx, search_radius)
        neigh = [j for j in neigh if j != current_idx and not used_mask[j]]

        if not neigh:
            break

        best_score = None
        best_idx = None
        best_dphi = None
        best_dtheta = None
        best_dr = None

        for j in neigh:
            rj     = r_cyl[j]
            phij   = phi_cyl[j]
            thetaj = theta_cyl[j]

            dr = rj - rc             # inward = negative
            if dr >= -min_step_dr:   # not going inward enough
                continue
            if dr < -max_step_dr:    # jump too large
                continue

            dphi = angle_diff(phij, phic)
            # NOTE(review): dphi is divided by dr (negative) only when |dr| > 0.5. It is
            # therefore sometimes an angle and sometimes an angle per unit length with
            # flipped sign, and the cuts below compare these mixed quantities.
            if abs(dr) > 0.5:
                dphi /= dr
            dtheta = thetaj - thetac

            # per-step cuts
            if abs(dphi) > max_dphi_step:
                continue
            if abs(dtheta) > max_dtheta_step:
                continue

            # smoothness vs previous step (prev_* are all None before the first step)
            if prev_dphi is not None and abs(dphi - prev_dphi) > max_delta_dphi:
                continue
            if prev_dtheta is not None and abs(dtheta - prev_dtheta) > max_delta_dtheta:
                continue
            if prev_dr is not None and abs(dr - prev_dr) > max_delta_r:
                continue

            # score: prefer smoother evolution; the dr term only exists once a
            # previous step is known
            score = abs(dphi - (prev_dphi or 0.0)) + abs(dtheta - (prev_dtheta or 0.0))
            if prev_dr is not None:
                score += 0.2 * abs(dr - prev_dr)

            if best_score is None or score < best_score:
                best_score = score
                best_idx = j
                best_dphi = dphi
                best_dtheta = dtheta
                best_dr = dr

        if best_idx is None:
            break

        # accept continuation
        chain.append(best_idx)
        used_mask[best_idx] = True

        prev_dphi = best_dphi
        prev_dtheta = best_dtheta
        prev_dr = best_dr
        current_idx = best_idx

    return chain


In [15]:
def build_smooth_inward_chains(
    search_radius=5.0,
    min_step_dr=0.2,
    max_step_dr=6.0,
    max_dphi_step=0.10,
    max_dtheta_step=0.10,
    max_delta_dphi=0.03,
    max_delta_dtheta=0.03,
    min_chain_hits_keep=20,
    max_chain_hits_keep=48,
    max_delta_r=2.0,
    r_seed_min=20.0,
):
    """
    Greedily build inward chains over all hits, seeding from the outermost hit first.

    Each hit that is not yet used and has ``r_cyl >= r_seed_min`` seeds a call to
    ``grow_chain_from_seed``. Growth is capped at ``max_chain_hits_keep`` hits. A chain
    is kept when its length lies in [``min_chain_hits_keep``, ``max_chain_hits_keep``].

    NOTE(review): hits of chains rejected as too short stay marked in ``used_mask``.
    Such hits can neither seed nor join later chains. Confirm this is intended.

    Parameters
    ----------
    search_radius, min_step_dr, max_step_dr, max_dphi_step, max_dtheta_step,
    max_delta_dphi, max_delta_dtheta, max_delta_r
        Forwarded to ``grow_chain_from_seed``.
    min_chain_hits_keep, max_chain_hits_keep : int
        Inclusive length window for kept chains.
    r_seed_min : float
        Hits with smaller r_cyl are never used as seeds. The default of 20.0 is the
        value that was previously hard-coded here.

    Returns
    -------
    list of list
        Kept chains, each a list of hit indices ordered outer -> inner.
    """
    N = len(cluster_points_cyl)
    seed_order = np.argsort(-r_cyl)  # descending radius: outermost seeds first

    used_mask = np.zeros(N, dtype=bool)
    chains = []

    for seed_idx in seed_order:
        if used_mask[seed_idx]:
            continue

        # Skip inner hits as seeds (e.g. don't seed from the MVTX area).
        if r_cyl[seed_idx] < r_seed_min:
            continue

        chain = grow_chain_from_seed(
            seed_idx,
            used_mask,
            search_radius=search_radius,
            min_step_dr=min_step_dr,
            max_step_dr=max_step_dr,
            max_dphi_step=max_dphi_step,
            max_dtheta_step=max_dtheta_step,
            max_delta_dphi=max_delta_dphi,
            max_delta_dtheta=max_delta_dtheta,
            max_chain_hits=max_chain_hits_keep,
            max_delta_r=max_delta_r
        )

        if min_chain_hits_keep <= len(chain) <= max_chain_hits_keep:
            chains.append(chain)

    print(f"Built {len(chains)} smooth inward chains in [{min_chain_hits_keep},{max_chain_hits_keep}] hits.")
    return chains

# Run it: first pass over all hits with the tuned cuts below.
FIRST_PASS_CUTS = dict(
    search_radius=3.0,
    min_step_dr=0,
    max_step_dr=3.0,
    max_dphi_step=0.4,
    max_dtheta_step=0.4,
    max_delta_dphi=0.005,
    max_delta_dtheta=0.01,
    min_chain_hits_keep=5,
    max_chain_hits_keep=248,
    max_delta_r=2.0,
)
chains = build_smooth_inward_chains(**FIRST_PASS_CUTS)

# Chain-length summary
lengths = [len(chain_hits) for chain_hits in chains]
if not lengths:
    print("No chains found in the given range.")
else:
    print("Chain length min / mean / max:",
          min(lengths), np.mean(lengths), max(lengths))


Built 6634 smooth inward chains in [5,248] hits.
Chain length min / mean / max: 5 8.246909858305697 34


In [16]:
# Cartesian hit positions (float32, shape [N, 3]) rebuilt from the wrapped
# cylindrical arrays. Reset first so a failure below cannot leave a stale array.
cluster_points = None
x, y, z = r_cyl * np.cos(phi_cyl), r_cyl * np.sin(phi_cyl), z_cyl
cluster_points = np.column_stack((x, y, z)).astype(np.float32)

In [17]:
# 3D ROOT display: every hit as a gray dot, with the first MAX_CHAINS_TO_DRAW chains
# overlaid as colored polylines plus markers.
MAX_CHAINS_TO_DRAW = 5000

# Defined unconditionally: the standard-vs-chains comparison cell reuses `colors`,
# `polylines` and `max_chains_to_draw`, so they must exist even when no chains were found.
colors = [root.kRed, root.kRed+3, root.kGreen+2, root.kMagenta, root.kOrange,
          root.kCyan, root.kViolet, root.kSpring, root.kTeal, root.kPink]
polylines = []  # Python references to every drawn chain object
max_chains_to_draw = min(MAX_CHAINS_TO_DRAW, len(chains))

if chains:
    c3d = root.TCanvas("c3d", "3D Chain Visualization", 1200, 900)

    # One-bin TH3F used only as a 3D frame around all hits (10 cm margin).
    x_min, x_max = cluster_points[:, 0].min(), cluster_points[:, 0].max()
    y_min, y_max = cluster_points[:, 1].min(), cluster_points[:, 1].max()
    z_min, z_max = cluster_points[:, 2].min(), cluster_points[:, 2].max()

    h3d_frame = root.TH3F("h3d_frame", "3D Chains;X [cm];Y [cm];Z [cm]",
                          1, x_min-10, x_max+10,
                          1, y_min-10, y_max+10,
                          1, z_min-10, z_max+10)
    h3d_frame.SetStats(0)
    h3d_frame.Draw()

    # Draw all hits as small markers (gray background)
    graph_all_hits = root.TGraph2D(len(cluster_points))
    for i in range(len(cluster_points)):
        graph_all_hits.SetPoint(i,
                                cluster_points[i, 0],
                                cluster_points[i, 1],
                                cluster_points[i, 2])
    graph_all_hits.SetMarkerStyle(7)  # small dots
    graph_all_hits.SetMarkerColor(root.kGray)
    graph_all_hits.Draw("P0 SAME")

    # Draw chains as colored polylines plus markers, cycling through `colors`
    for i in range(max_chains_to_draw):
        chain = chains[i]
        color = colors[i % len(colors)]

        # --- polyline in chain order ---
        pl = root.TPolyLine3D(len(chain))
        for j, idx in enumerate(chain):
            x, y, z = cluster_points[idx]
            pl.SetPoint(j, float(x), float(y), float(z))
        pl.SetLineColor(color)
        pl.SetLineWidth(4)
        pl.Draw()
        polylines.append(pl)

        # --- markers in chain order ---
        graph_chain = root.TGraph2D(len(chain))
        for j, idx in enumerate(chain):
            x, y, z = cluster_points[idx]
            graph_chain.SetPoint(j, float(x), float(y), float(z))
        graph_chain.SetMarkerStyle(20)
        graph_chain.SetMarkerSize(0.8)
        graph_chain.SetMarkerColor(color)
        graph_chain.Draw("P0 SAME")
        polylines.append(graph_chain)

    # Otherwise the frame and all-hits graph are referenced only by module-level names.
    # Attaching them to the canvas keeps them alive if a later cell re-binds those
    # names, so this canvas (also written to file later) keeps what it displays.
    c3d._objs = [h3d_frame, graph_all_hits, polylines]

    print(f"3D ROOT plot: Showing {max_chains_to_draw} chains out of {len(chains)} total")
    print(f"Total hits displayed: {len(cluster_points)}")
else:
    print("No chains to plot")


3D ROOT plot: Showing 5000 chains out of 6634 total
Total hits displayed: 195129


In [18]:
do_continue: bool = True  # gate read by the standard-vs-chains cell before drawing its canvas

In [19]:
# Collect the global (x, y, z) of every cluster attached to a standard-reconstruction
# track (one residualtree entry per track), remembering which entry each came from.
# Initialised here so later cells can test `is not None` even without a residual tree.
track_cluster_points = None
track_cluster_entry_index = None  # maps point index -> TTree entry (track) index
N_of_tracks_standard = 0
N_clusters_ontrack_standard = 0   # counts only on-track clusters with 20 < r < 80 (cm)
if residual_tree:
    txs, tys, tzs = [], [], []
    t_entry_idx = []
    n_entries = int(residual_tree.GetEntries())
    print(f"Building track cluster index from {n_entries} residual_tree entries...")
    for i in range(n_entries):
        residual_tree.GetEntry(i)
        clusgx, clusgy, clusgz = residual_tree.clusgx, residual_tree.clusgy, residual_tree.clusgz
        nclus = len(clusgx)
        # Sanity check (optional)
        # assert nclus == len(clusgy) == len(clusgz)
        N_of_tracks_standard += 1
        for j in range(nclus):
            cx, cy, cz = float(clusgx[j]), float(clusgy[j]), float(clusgz[j])
            txs.append(cx)
            tys.append(cy)
            tzs.append(cz)
            rtclus = (cx**2 + cy**2)**0.5
            if 20 < rtclus < 80:
                N_clusters_ontrack_standard += 1
            # remember which track (tree entry) this cluster belongs to
            t_entry_idx.append(i)

    # shape: [N_total_clusters, 3]
    track_cluster_points = np.column_stack([txs, tys, tzs]).astype(np.float32)
    # shape: [N_total_clusters], values are tree entry indices (track ids)
    track_cluster_entry_index = np.array(t_entry_idx, dtype=np.int64)
    print(f"Number of tracks {N_of_tracks_standard}")
    print(f"Number of clusters on tracks {N_clusters_ontrack_standard}")
    print(f"track_cluster_points shape: {track_cluster_points.shape}")
    print(f"track_cluster_entry_index shape: {track_cluster_entry_index.shape}")
else:
    print("No residual_tree available; spatial index not built.")

Building track cluster index from 4616 residual_tree entries...
Number of tracks 4616
Number of clusters on tracks 179171
track_cluster_points shape: (197425, 3)
track_cluster_entry_index shape: (197425,)


In [20]:
# Keep only the track clusters that are finite and have a cluster-tree hit within
# `tolerance` (3D distance), so that both point sets describe the same hits.
if cluster_points is not None and track_cluster_points is not None:
    print("Before cleaning: cluster_points:", cluster_points.shape, " track_cluster_points:", track_cluster_points.shape)

    # Step 1: NaN/inf rows cause ROOT TGraph2D errors; drop them together with their track ids.
    finite_mask = np.isfinite(track_cluster_points).all(axis=1)
    if not finite_mask.all():
        removed = (~finite_mask).sum()
        print(f"Removing {removed} non-finite track cluster points.")
        track_cluster_points = track_cluster_points[finite_mask]
        track_cluster_entry_index = track_cluster_entry_index[finite_mask]

    # The cluster-tree points should already be finite. This copy is only used when
    # a KD-tree has to be built here.
    if np.isfinite(cluster_points).all():
        cluster_points_clean = cluster_points
    else:
        cf_mask = np.isfinite(cluster_points).all(axis=1)
        print(f"Warning: {(~cf_mask).sum()} non-finite cluster_points rows ignored.")
        cluster_points_clean = cluster_points[cf_mask]

    # Step 2: reuse the Cartesian KD-tree when it exists, otherwise index the clean points.
    if 'spatial_index_xyz' in globals() and spatial_index_xyz is not None:
        tree = spatial_index_xyz
    else:
        tree = cKDTree(cluster_points_clean)

    tolerance = 1.0  # adjust as needed

    # Step 3: one vectorised nearest-neighbour query. cKDTree reports "no neighbour
    # within distance_upper_bound" as an infinite distance.
    dists, nn_idx = tree.query(track_cluster_points, distance_upper_bound=tolerance)
    keep_mask = np.isfinite(dists)
    kept = int(keep_mask.sum())
    dropped = len(keep_mask) - kept
    if dropped:
        print(f"Dropping {dropped} track clusters without match within tolerance {tolerance}.")

    track_cluster_points = track_cluster_points[keep_mask]
    track_cluster_entry_index = track_cluster_entry_index[keep_mask]

    print("After cleaning: cluster_points:", cluster_points.shape, " track_cluster_points:", track_cluster_points.shape)

Before cleaning: cluster_points: (195129, 3)  track_cluster_points: (197425, 3)
Removing 3 non-finite track cluster points.
Dropping 96501 track clusters without match within tolerance 1.0.
After cleaning: cluster_points: (195129, 3)  track_cluster_points: (100921, 3)


In [21]:
# Remove vertex-detector hits: keep only track clusters with transverse radius r >= 25 cm.
if track_cluster_points is not None:
    print("Before cleaning vtx hits: track_cluster_points:", track_cluster_points.shape)
    _xy = track_cluster_points[:, :2]
    r_track_clusters = np.sqrt(_xy[:, 0]**2 + _xy[:, 1]**2)
    vtx_mask = r_track_clusters >= 25.0          # True = keep
    removed_vtx = (~vtx_mask).sum()
    if removed_vtx:
        print(f"Removing {removed_vtx} track cluster points with r < 25 cm (vtx hits).")
        track_cluster_points, track_cluster_entry_index = (
            track_cluster_points[vtx_mask],
            track_cluster_entry_index[vtx_mask],
        )
    print("After cleaning vtx hits: track_cluster_points:", track_cluster_points.shape)

Before cleaning vtx hits: track_cluster_points: (100921, 3)
Removing 18162 track cluster points with r < 25 cm (vtx hits).
After cleaning vtx hits: track_cluster_points: (82759, 3)


In [22]:
# Standard-reconstruction track clusters (red markers) overlaid with the chain polylines.
c_clust_standard = root.TCanvas("c_clust_standard", "All clusters and found on tracks", 1200, 900)

# Frame spanning all hits with a 10 cm margin. It gets its own Python and ROOT name.
# Re-binding `h3d_frame` here would drop the last Python reference to the frame drawn
# on `c3d`, and would create a second "h3d_frame" histogram in ROOT's current directory.
x_min, x_max = cluster_points[:, 0].min(), cluster_points[:, 0].max()
y_min, y_max = cluster_points[:, 1].min(), cluster_points[:, 1].max()
z_min, z_max = cluster_points[:, 2].min(), cluster_points[:, 2].max()

h3d_frame_std = root.TH3F("h3d_frame_std", "3D Chains;X [cm];Y [cm];Z [cm]",
                          1, x_min-10, x_max+10,
                          1, y_min-10, y_max+10,
                          1, z_min-10, z_max+10)
h3d_frame_std.SetStats(0)
h3d_frame_std.Draw()

g_track_clusters = root.TGraph2D(len(track_cluster_points))
for i, (x, y, z) in enumerate(track_cluster_points):
    g_track_clusters.SetPoint(i, float(x), float(y), float(z))
g_track_clusters.SetMarkerStyle(20)
g_track_clusters.SetMarkerSize(0.6)
g_track_clusters.SetMarkerColor(root.kRed)
g_track_clusters.Draw("P0 SAME")

# All hits in gray; built and kept referenced, but currently not drawn on this canvas.
g_all_clusters = root.TGraph2D(len(cluster_points))
for i, (x, y, z) in enumerate(cluster_points):
    g_all_clusters.SetPoint(i, float(x), float(y), float(z))
g_all_clusters.SetMarkerStyle(7)
g_all_clusters.SetMarkerSize(0.1)
g_all_clusters.SetMarkerColor(root.kGray+1)

# Chain drawing settings normally come from the 3D chain plot cell. Fall back to the
# same defaults so this cell also works when that cell drew nothing.
colors = globals().get("colors") or [root.kRed, root.kRed+3, root.kGreen+2, root.kMagenta,
                                     root.kOrange, root.kCyan, root.kViolet, root.kSpring,
                                     root.kTeal, root.kPink]
polylines = globals().get("polylines", [])
max_chains_to_draw = min(globals().get("max_chains_to_draw", 5000), len(chains))

for i in range(max_chains_to_draw):
    chain = chains[i]

    # --- polyline in chain order ---
    pl = root.TPolyLine3D(len(chain))
    for j, idx in enumerate(chain):
        x, y, z = cluster_points[idx]
        pl.SetPoint(j, float(x), float(y), float(z))

    pl.SetLineColor(colors[i % len(colors)])
    pl.SetLineWidth(4)
    pl.Draw()
    polylines.append(pl)

    # --- markers in chain order: built and kept, not drawn on this canvas ---
    graph_chain = root.TGraph2D(len(chain))
    for j, idx in enumerate(chain):
        x, y, z = cluster_points[idx]
        graph_chain.SetPoint(j, float(x), float(y), float(z))

    graph_chain.SetMarkerStyle(20)
    graph_chain.SetMarkerSize(0.8)
    graph_chain.SetMarkerColor(colors[i % len(colors)])
    polylines.append(graph_chain)

# keep Python references so ROOT objects are not deleted
c_clust_standard._objs = [h3d_frame_std, g_track_clusters, g_all_clusters]
if do_continue:
    c_clust_standard.Draw()

In [23]:
# Compare cluster usage between the standard tracking and the chain finder.
print(f"Total number of clusters {len(cluster_points)}.")
N_clusters_chain = sum(len(chain_hits) for chain_hits in chains)

print(" ")
print(f"Number of clusters found on  tracks by standard algorithm {len(track_cluster_points)}.")
print(f"Number of clusters found on  chains {N_clusters_chain}.")

print(" ")
print(f"Number of  tracks by standard algorithm {N_of_tracks_standard}.")
print(f"Number of chains found {len(chains)}.")

Total number of clusters 195129.
 
Number of clusters found on  tracks by standard algorithm 82759.
Number of clusters found on  chains 54710.
 
Number of  tracks by standard algorithm 4616.
Number of chains found 6634.


In [24]:
# Save the 3D chain canvas, with the primitives drawn on it, to a ROOT file.
import os

CHAINS_DISPLAY_PATH = "output/chains_display_central.root"

if "c3d" not in globals():
    # c3d is only created by the 3D chain plot cell when chains were found.
    print("No c3d canvas to save (no chains were plotted).")
else:
    # RECREATE fails (zombie file) if the directory is missing, so create it first.
    os.makedirs(os.path.dirname(CHAINS_DISPLAY_PATH) or ".", exist_ok=True)
    outf = root.TFile(CHAINS_DISPLAY_PATH, "RECREATE")
    if outf.IsZombie():
        raise OSError(f"Could not create {CHAINS_DISPLAY_PATH}")
    try:
        c3d.Write("c3d")   # serializes the canvas + its primitives
    finally:
        outf.Close()       # always release the file, even if Write fails

In [25]:
do_continue: bool = False  # turn off the gated canvas drawing from here on

In [26]:
# 3D ROOT plot of hits NOT on any chain
import numpy as np

# Drawing the off-chain hits is off by default; set True to render them.
DRAW_OFF_CHAIN_HITS = False

assert 'cluster_points' in globals(), "cluster_points (Nx3 XYZ) not found"
assert 'chains' in globals(), "chains list not found"

N = len(cluster_points)

# Build a boolean mask of hits that are on any chain
on_track = np.zeros(N, dtype=bool)
for ch in chains:
    on_track[np.asarray(ch, dtype=int)] = True

off_idx = np.where(~on_track)[0]
num_off = len(off_idx)

print(f"Total hits: {N}")
print(f"Hits on chains: {int(on_track.sum())}")
print(f"Hits NOT on any chain: {num_off}")

if num_off == 0:
    print("All hits are assigned to chains — nothing to draw.")
elif not DRAW_OFF_CHAIN_HITS:
    # Previously a hard-coded `and False` sent this case to the "all assigned"
    # message above, which was wrong whenever off-chain hits existed.
    print("Drawing of off-chain hits is disabled; set DRAW_OFF_CHAIN_HITS = True to draw them.")
else:
    # Canvas & 3D frame
    c3d_off = root.TCanvas("c3d_off", "3D: Hits NOT on chains", 1200, 900)

    x_min, x_max = cluster_points[:, 0].min(), cluster_points[:, 0].max()
    y_min, y_max = cluster_points[:, 1].min(), cluster_points[:, 1].max()
    z_min, z_max = cluster_points[:, 2].min(), cluster_points[:, 2].max()

    h3d_frame_off = root.TH3F("h3d_frame_off",
                              "3D Hits NOT on Chains;X [cm];Y [cm];Z [cm]",
                              1, x_min-10, x_max+10,
                              1, y_min-10, y_max+10,
                              1, z_min-10, z_max+10)
    h3d_frame_off.SetStats(0)
    h3d_frame_off.Draw()

    # Draw only off-track hits as tiny gray dots
    g_off = root.TGraph2D(num_off)
    for i, idx in enumerate(off_idx):
        x, y, z = cluster_points[idx]
        g_off.SetPoint(i, float(x), float(y), float(z))
    g_off.SetMarkerStyle(7)              # tiny dots
    g_off.SetMarkerColor(root.kGray+1)   # slightly brighter gray
    g_off.Draw("P0 SAME")

    # keep Python references on the canvas so the drawn objects stay alive
    c3d_off._objs = [h3d_frame_off, g_off]
    c3d_off.Draw()


Total hits: 195129
Hits on chains: 54710
Hits NOT on any chain: 140419
All hits are assigned to chains — nothing to draw.


In [27]:
# ================= FLEXIBLE CYL-SPACE CHAIN FINDER (NO GLOBALS REQUIRED) =================
import numpy as np

# --- 0) A light container for the cylindrical data & a neighbor function
class CylSpace:
    """
    Per-hit coordinate arrays plus a neighbour-query callable, bundled so the chain
    finder can run on any hit set instead of notebook globals.

    Attributes
    ----------
    r, phi, theta
        Per-hit sequences indexed by hit number.
    pts
        The (r, phi, z) points the space describes (shape [N, 3]).
    neighbors
        Callable ``neighbors(i, radius) -> list[int]`` returning hit indices near hit ``i``.
    """

    def __init__(self, r, phi, theta, pts_rphiz, neighbors_fn):
        self.r, self.phi, self.theta = r, phi, theta
        self.pts = pts_rphiz
        self.neighbors = neighbors_fn

def angle_diff(a, b):
    """Signed difference a - b in radians, wrapped to [-pi, pi] (element-wise for arrays)."""
    shifted = a - b + np.pi
    return shifted % (2 * np.pi) - np.pi

# --- 1) Build a NEW KDTree + neighbor fn for ANY (r,phi,z) array (does not touch globals)
def make_cyl_index(pts_rphiz):
    """
    Build a fresh KD-tree over ``pts_rphiz`` and return a radius-query function for it.

    Does not read or modify any notebook globals.

    Parameters
    ----------
    pts_rphiz : array-like, shape (N, 3)
        Points to index, nominally (r, phi, z). The tree uses plain Euclidean distance
        on these columns, so phi (radians) is mixed with r and z and is not wrapped
        at +-pi.

    Returns
    -------
    (neighbors_fn, backend)
        ``neighbors_fn(i, radius)`` returns a list of indices into ``pts_rphiz`` of the
        points within ``radius`` of point ``i``, including ``i`` itself.
        ``backend`` is "scipy.cKDTree" or "sklearn.KDTree".

    Raises
    ------
    ImportError
        If neither scipy nor scikit-learn can be imported.
    ValueError
        Propagated from the tree constructor for unusable input (e.g. not 2-D).
        Such errors no longer trigger a silent backend fallback.
    """
    try:
        from scipy.spatial import cKDTree as _KDTree
    except ImportError:
        _KDTree = None

    if _KDTree is not None:
        index = _KDTree(pts_rphiz)
        backend = "scipy.cKDTree"

        def neighbors_fn(i, radius):
            return index.query_ball_point(pts_rphiz[i], r=radius)
    else:
        # scipy missing: fall back to scikit-learn (its ImportError propagates if absent too)
        from sklearn.neighbors import KDTree as _SKKDTree
        index = _SKKDTree(pts_rphiz)
        backend = "sklearn.KDTree"

        def neighbors_fn(i, radius):
            return index.query_radius(pts_rphiz[i:i+1], r=radius)[0].tolist()

    return neighbors_fn, backend

# --- 2) A function to get a CylSpace from your current GLOBALS (keeps old behavior working)
def cylspace_from_globals():
    """
    Wrap the first-pass notebook globals in a CylSpace.

    Uses r_cyl, phi_cyl, theta_cyl, cluster_points_cyl and neighbors_radius_cyl.
    The latter performs Cartesian (x, y, z) radius queries.
    Fails an ``assert`` if any of these names is not defined yet.
    """
    g = globals()
    assert 'r_cyl' in g and 'phi_cyl' in g and 'theta_cyl' in g
    assert 'cluster_points_cyl' in g
    assert 'neighbors_radius_cyl' in g
    return CylSpace(g['r_cyl'], g['phi_cyl'], g['theta_cyl'],
                    g['cluster_points_cyl'], g['neighbors_radius_cyl'])

# --- 3) Your grow_chain_from_seed, now taking an optional CylSpace (defaults to globals)
def grow_chain_from_seed(
    seed_idx,
    used_mask,
    search_radius=5.0,
    min_step_dr=0.2,
    max_step_dr=6.0,
    max_dphi_step=0.10,
    max_dtheta_step=0.10,
    max_delta_r=2.0,
    max_delta_dphi=0.03,
    max_delta_dtheta=0.03,
    max_chain_hits=260,
    cyl: CylSpace | None = None,  # <--- NEW
):
    """
    Grow one chain greedily from ``seed_idx``, stepping inward in r with smooth dphi/dtheta.

    At each step the candidates are the unused hits returned by
    ``cyl.neighbors(current, search_radius)``. A candidate is considered only if:

    * ``-max_step_dr <= dr < -min_step_dr``, where ``dr = r_j - r_current`` (inward);
    * ``|dphi| <= max_dphi_step`` and ``|dtheta| <= max_dtheta_step``;
    * after the first accepted step, dphi, dtheta and dr each differ from the previous
      step's value by at most ``max_delta_dphi``, ``max_delta_dtheta``, ``max_delta_r``.

    The candidate with the lowest smoothness score is appended; on ties the first one
    encountered wins. Growth stops when nothing survives or the chain reaches
    ``max_chain_hits`` hits.

    Parameters
    ----------
    seed_idx : int
        Starting hit, an index into ``cyl.r`` / ``cyl.phi`` / ``cyl.theta``.
    used_mask : numpy.ndarray of bool
        Modified in place: the seed and every appended hit are set to True.
    cyl : CylSpace or None
        Hit arrays and neighbour function to use. None means the notebook globals,
        via ``cylspace_from_globals()`` (backwards-compatible).

    Returns
    -------
    list
        Hit indices in chain order (outer -> inner), starting with ``seed_idx``.
    """
    if cyl is None:
        cyl = cylspace_from_globals()

    r_arr, phi_arr, theta_arr = cyl.r, cyl.phi, cyl.theta
    neighbors = cyl.neighbors

    chain = [seed_idx]
    used_mask[seed_idx] = True

    prev_dphi = None
    prev_dtheta = None
    prev_dr = None
    current_idx = seed_idx

    for _ in range(max_chain_hits - 1):
        rc     = r_arr[current_idx]
        phic   = phi_arr[current_idx]
        thetac = theta_arr[current_idx]

        neigh = neighbors(current_idx, search_radius)
        neigh = [j for j in neigh if j != current_idx and not used_mask[j]]
        if not neigh:
            break

        best_idx = None
        best_score = None
        best_dphi = best_dtheta = best_dr = None

        for j in neigh:
            rj, phij, thetaj = r_arr[j], phi_arr[j], theta_arr[j]

            dr = rj - rc                 # inward => negative
            if dr >= -min_step_dr:       # not inward enough
                continue
            if dr < -max_step_dr:        # huge inward jump
                continue

            dphi = angle_diff(phij, phic)
            # NOTE(review): dphi is divided by dr (negative) only when |dr| > 0.5. It is
            # therefore sometimes an angle and sometimes an angle per unit length with
            # flipped sign, and the cuts below compare these mixed quantities.
            if abs(dr) > 0.5:
                dphi = dphi / dr

            dtheta = thetaj - thetac

            # per-step limits
            if abs(dphi) > max_dphi_step:     continue
            if abs(dtheta) > max_dtheta_step: continue

            # smoothness vs previous step (prev_* are all None before the first step)
            if prev_dphi is not None and abs(dphi - prev_dphi) > max_delta_dphi:
                continue
            if prev_dtheta is not None and abs(dtheta - prev_dtheta) > max_delta_dtheta:
                continue
            if prev_dr is not None and abs(dr - prev_dr) > max_delta_r:
                continue

            # score: prefer smoother evolution; the dr term only exists once a
            # previous step is known
            score = abs(dphi - (prev_dphi or 0.0)) + abs(dtheta - (prev_dtheta or 0.0))
            if prev_dr is not None:
                score += 0.2 * abs(dr - prev_dr)

            if best_score is None or score < best_score:
                best_score = score
                best_idx   = j
                best_dphi, best_dtheta, best_dr = dphi, dtheta, dr

        if best_idx is None:
            break

        chain.append(best_idx)
        used_mask[best_idx] = True
        prev_dphi, prev_dtheta, prev_dr = best_dphi, best_dtheta, best_dr
        current_idx = best_idx

    return chain

# --- 4) Build chains, also optionally on a provided CylSpace (defaults to globals)
def build_smooth_inward_chains(
    search_radius=5.0,
    min_step_dr=0.2,
    max_step_dr=6.0,
    max_dphi_step=0.10,
    max_dtheta_step=0.10,
    max_delta_dphi=0.03,
    max_delta_dtheta=0.03,
    min_chain_hits_keep=20,
    max_chain_hits_keep=48,
    max_delta_r=2.0,
    r_seed_min=20.0,
    require_unique=True,
    cyl: CylSpace | None = None,
):
    """
    Seed chains from hits in order of decreasing radius and grow each one inward.

    Parameters
    ----------
    search_radius, min_step_dr, max_step_dr, max_dphi_step, max_dtheta_step,
    max_delta_dphi, max_delta_dtheta, max_delta_r
        Forwarded unchanged to ``grow_chain_from_seed``.
    min_chain_hits_keep, max_chain_hits_keep
        Inclusive length window a grown chain must fall in to be kept.
        ``max_chain_hits_keep`` is also forwarded as ``max_chain_hits``.
    r_seed_min
        Hits with ``r < r_seed_min`` are never used as seeds.
    require_unique
        If True, a hit already marked in the shared used-mask is not reused as a seed.
    cyl
        CylSpace to search in. When None, one is built with ``cylspace_from_globals()``.
        The same object is passed to every ``grow_chain_from_seed`` call.

    Returns
    -------
    list
        Accepted chains, as returned by ``grow_chain_from_seed`` (hit indices).

    Notes
    -----
    The used-mask is shared across all seeds and is marked inside
    ``grow_chain_from_seed``. Hits picked up by a chain that is later rejected on
    length therefore stay marked here.
    """
    space = cyl if cyl is not None else cylspace_from_globals()

    radii = space.r
    claimed = np.zeros(len(radii), dtype=bool)

    # Everything except seed / mask / space is identical for every call.
    grow_params = dict(
        search_radius=search_radius,
        min_step_dr=min_step_dr,
        max_step_dr=max_step_dr,
        max_dphi_step=max_dphi_step,
        max_dtheta_step=max_dtheta_step,
        max_delta_r=max_delta_r,
        max_delta_dphi=max_delta_dphi,
        max_delta_dtheta=max_delta_dtheta,
        max_chain_hits=max_chain_hits_keep,
    )

    accepted = []
    for seed in np.argsort(-radii):  # outermost hit first
        if radii[seed] < r_seed_min or (require_unique and claimed[seed]):
            continue

        grown = grow_chain_from_seed(seed, claimed, cyl=space, **grow_params)

        if min_chain_hits_keep <= len(grown) <= max_chain_hits_keep:
            accepted.append(grown)

    print(f"Built {len(accepted)} smooth inward chains in [{min_chain_hits_keep},{max_chain_hits_keep}] hits.")
    return accepted

# --- 5) Example: SECOND PASS on leftover hits with a brand-new KDTree, NO globals touch ---

def run_second_pass_on_leftovers(chains_first_pass,
                                 cluster_points,          # (N,3) XYZ (only for plotting)
                                 cluster_points_cyl_full, # (N,3) (r,phi,z)
                                 r_full, phi_full, theta_full,
                                 looser_kwargs=None):
    """
    Re-run the chain finder on the hits that no first-pass chain claimed.

    A fresh neighbour index is built over the leftover hits only. The chain finder
    is given that subset explicitly, so no globals are consulted or modified.

    Parameters
    ----------
    chains_first_pass : iterable of sequences of int
        First-pass chains, as global hit indices.
    cluster_points
        Not used by this function; kept for call-site compatibility.
    cluster_points_cyl_full : ndarray
        Per-hit coordinates. Column 2 is read as z. The call-site comment says
        (N, 3) = (r, phi, z); confirm.
    r_full, phi_full : ndarray
        Per-hit radius and azimuth, indexed like ``cluster_points_cyl_full``.
    theta_full
        Not used; theta is recomputed on the subset as ``arctan2(r, z)``.
    looser_kwargs : dict or None
        Overrides merged into the default second-pass settings passed to
        ``build_smooth_inward_chains``.

    Returns
    -------
    tuple
        ``(chains_pass2_global, off_idx, cyl_sub)``:
        - the new chains mapped back to global hit indices;
        - the global indices of the leftover hits;
        - the CylSpace built on them, or None when there are no leftovers.
    """
    n_hits = len(cluster_points_cyl_full)

    claimed = np.zeros(n_hits, dtype=bool)
    for first_chain in chains_first_pass:
        claimed[np.asarray(first_chain, dtype=int)] = True

    off_idx = np.flatnonzero(~claimed)
    n_left = off_idx.size
    print(f"[2nd pass] leftovers: {n_left}/{n_hits}")

    if not n_left:
        return [], off_idx, None

    # Cylindrical quantities restricted to the leftover hits.
    pts_sub = cluster_points_cyl_full[off_idx].astype(np.float32)
    r_sub = r_full[off_idx].astype(float)
    phi_in = phi_full[off_idx].astype(float)
    phi_sub = np.arctan2(np.sin(phi_in), np.cos(phi_in))  # wrap into (-pi, pi]
    z_sub = pts_sub[:, 2].astype(float)
    theta_sub = np.arctan2(r_sub, z_sub)

    neighbors_sub, backend = make_cyl_index(pts_sub)
    cyl_sub = CylSpace(r_sub, phi_sub, theta_sub, pts_sub, neighbors_sub)
    print(f"[2nd pass] subset KDTree backend: {backend}")

    # Default second-pass settings; caller overrides win.
    settings = {
        "search_radius": 3.0,
        "min_step_dr": 0.0,
        "max_step_dr": 3.0,
        "max_dphi_step": 0.4,
        "max_dtheta_step": 0.4,
        "max_delta_dphi": 0.01,
        "max_delta_dtheta": 0.01,
        "min_chain_hits_keep": 5,
        "max_chain_hits_keep": 248,
        "max_delta_r": 2.0,
        "r_seed_min": 0.0,
        "require_unique": True,
        "cyl": cyl_sub,
    }
    settings.update(looser_kwargs or {})

    sub_chains = build_smooth_inward_chains(**settings)

    # Subset-local indices -> global hit indices.
    chains_pass2_global = [[int(off_idx[local]) for local in sub] for sub in sub_chains]
    print(f"[2nd pass] new chains: {len(chains_pass2_global)}")
    return chains_pass2_global, off_idx, cyl_sub


In [28]:
# Second pass over the hits the first pass left unassigned.
# Inputs expected from earlier cells:
#   chains                     - first-pass chains (lists of global hit indices)
#   cluster_points             - XYZ hit positions
#   cluster_points_cyl         - (r, phi, z) hit positions
#   r_cyl, phi_cyl, theta_cyl  - full per-hit arrays
if do_continue:
    chains_pass2, off_idx, cyl_sub = run_second_pass_on_leftovers(
        chains,
        cluster_points,
        cluster_points_cyl,
        r_cyl,
        phi_cyl,
        theta_cyl,
        looser_kwargs={
            "search_radius": 3.0,
            # Negative on purpose: the step cut `dr >= -min_step_dr` then only rejects
            # dr >= 3, so outward steps smaller than 3 are admitted.
            "min_step_dr": -3,
            "max_step_dr": 3.0,
            "max_dphi_step": 0.05,
            "max_dtheta_step": 0.4,
            "max_delta_dphi": 0.01,
            "max_delta_dtheta": 0.01,
            "min_chain_hits_keep": 5,
            "max_chain_hits_keep": 248,
            "max_delta_r": 3.0,
        },
    )

# chains_pass2 holds GLOBAL hit indices, drawn only from hits left over by pass 1.
# The drawing cell below consumes chains_pass2 and off_idx.


In [29]:
# ================== DRAW: pass-2 chains on leftover hits, then residual leftovers ==================
import numpy as np

# Set True to also draw the hits that no chain picked up after pass 2.
# (Previously this was hard-disabled with `and False`, which also made the log
# claim "No residual leftovers" even when there were some.)
DRAW_RESIDUALS = False

if do_continue:
    assert 'cluster_points' in globals(), "cluster_points (XYZ) missing"
    assert 'chains_pass2' in globals(), "chains_pass2 not found (run second pass first)"
    assert 'off_idx' in globals(), "off_idx from second pass not found"

    N_xyz = len(cluster_points)
    colors = [root.kRed, root.kRed+3, root.kGreen+2, root.kMagenta, root.kOrange,
              root.kCyan, root.kViolet, root.kSpring, root.kTeal, root.kPink]

    # PyROOT deletes the C++ object when its Python proxy is garbage-collected, and a
    # deleted primitive disappears from the pad. Keep every per-chain object referenced
    # so earlier chains are not lost when the loop variables are rebound.
    pass2_keepalive = []

    # Leftover indices that exist in the XYZ array (others cannot be drawn).
    off_idx_xyz = off_idx[off_idx < N_xyz]

    # ---------- 1) PASS-2 VIEW: leftover hits (from pass-1) + new chains ----------
    if len(off_idx_xyz) == 0:
        print("[draw] No leftover hits from pass-1 to draw.")
    else:
        # Random suffix keeps canvas/histogram names unique across re-runs of the cell.
        suffix = f"_{np.random.randint(10**9)}"
        c_pass2 = root.TCanvas(f"c_pass2{suffix}", "3D: Pass-2 (leftovers + new chains)", 1200, 900)

        xyz_off = cluster_points[off_idx_xyz]
        x_min, x_max = xyz_off[:,0].min(), xyz_off[:,0].max()
        y_min, y_max = xyz_off[:,1].min(), xyz_off[:,1].max()
        z_min, z_max = xyz_off[:,2].min(), xyz_off[:,2].max()

        # Single-bin TH3F used only as a 3D frame with 10 cm padding.
        h3 = root.TH3F(f"h3_pass2{suffix}",
                       "3D Pass-2: leftover hits + new chains;X [cm];Y [cm];Z [cm]",
                       1, x_min-10, x_max+10,
                       1, y_min-10, y_max+10,
                       1, z_min-10, z_max+10)
        h3.SetStats(0)
        h3.Draw()

        # Background: leftover hits (from pass-1)
        g_bg = root.TGraph2D(len(off_idx_xyz))
        for i, idx in enumerate(off_idx_xyz):
            X, Y, Z = cluster_points[idx]
            g_bg.SetPoint(i, float(X), float(Y), float(Z))
        g_bg.SetMarkerStyle(7)
        g_bg.SetMarkerColor(root.kGray+1)
        g_bg.Draw("P0 SAME")

        # Overlay: new chains (pass-2)
        draw_max = min(5000, len(chains_pass2))
        n_drawn = 0
        for i in range(draw_max):
            col = colors[i % len(colors)]

            # Collect in-range hits first so the ROOT objects are sized exactly;
            # slots that are never set would otherwise be drawn at (0, 0, 0).
            xyz_chain = [cluster_points[idx] for idx in chains_pass2[i] if 0 <= idx < N_xyz]
            if not xyz_chain:
                continue

            line = root.TPolyLine3D(len(xyz_chain))
            g_chain = root.TGraph2D(len(xyz_chain))
            for k, (X, Y, Z) in enumerate(xyz_chain):
                line.SetPoint(k, float(X), float(Y), float(Z))
                g_chain.SetPoint(k, float(X), float(Y), float(Z))

            # Polyline
            line.SetLineColor(col)
            line.SetLineWidth(4)
            line.Draw()

            # Markers
            g_chain.SetMarkerStyle(20)
            g_chain.SetMarkerSize(0.8)
            g_chain.SetMarkerColor(col)
            g_chain.Draw("P0 SAME")

            pass2_keepalive.extend((line, g_chain))
            n_drawn += 1

        c_pass2.Draw()
        print(f"[draw] Pass-2: drew {n_drawn} chains over {len(off_idx_xyz)} leftover hits.")

    # ---------- 2) RESIDUAL LEFTOVERS: hits still not on any chain after pass-2 ----------
    # Build on-track mask that includes pass-2 chains
    on_track_after_p2 = np.zeros(N_xyz, dtype=bool)
    # If you want to include pass-1 chains as well, union them here:
    # for ch in chains: on_track_after_p2[np.asarray(ch, int)] = True
    for ch in chains_pass2:
        valid = [idx for idx in ch if 0 <= idx < N_xyz]
        on_track_after_p2[np.asarray(valid, dtype=int)] = True

    # Residual = leftover from pass-1 minus those used by pass-2
    residual_mask = np.zeros(N_xyz, dtype=bool)
    residual_mask[off_idx_xyz] = True
    residual_mask &= ~on_track_after_p2
    residual_idx = np.where(residual_mask)[0]

    print(f"[draw] Residual leftover hits after pass-2: {len(residual_idx)}")

    if len(residual_idx) == 0:
        print("[draw] No residual leftovers after pass-2.")
    elif not DRAW_RESIDUALS:
        print("[draw] Residual drawing disabled (set DRAW_RESIDUALS = True to draw them).")
    else:
        suffix2 = f"_{np.random.randint(10**9)}"
        c_resid = root.TCanvas(f"c_resid{suffix2}", "3D: Residual leftovers after pass-2", 1200, 900)

        xyz_res = cluster_points[residual_idx]
        xmi, xma = xyz_res[:,0].min(), xyz_res[:,0].max()
        ymi, yma = xyz_res[:,1].min(), xyz_res[:,1].max()
        zmi, zma = xyz_res[:,2].min(), xyz_res[:,2].max()

        h3r = root.TH3F(f"h3_resid{suffix2}",
                        "3D Residual Leftovers After Pass-2;X [cm];Y [cm];Z [cm]",
                        1, xmi-10, xma+10,
                        1, ymi-10, yma+10,
                        1, zmi-10, zma+10)
        h3r.SetStats(0)
        h3r.Draw()

        g_res = root.TGraph2D(len(residual_idx))
        for i, idx in enumerate(residual_idx):
            X, Y, Z = cluster_points[idx]
            g_res.SetPoint(i, float(X), float(Y), float(Z))
        g_res.SetMarkerStyle(7)
        g_res.SetMarkerColor(root.kGray+2)
        g_res.Draw("P0 SAME")

        c_resid.Draw()
    

In [30]:
# Save the 3D canvas `c3d` (built in an earlier cell) to a ROOT file.
import os

out_path = "output/chains_display_central.root"
# TFile does not create missing parent directories; without this it opens as a zombie.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

outf = root.TFile(out_path, "RECREATE")
if not outf or outf.IsZombie():
    raise OSError(f"Could not open {out_path} for writing")
try:
    outf.cd()            # make sure the write goes into this file
    c3d.Write("c3d")     # serializes the canvas + its primitives
finally:
    outf.Close()         # always release the file, even if Write raises

In [31]:
import numpy as np

def check_chain_max_jump(chain, max_dist, points):
    """
    Find consecutive steps along a chain that are longer than ``max_dist``.

    Parameters
    ----------
    chain : sequence of int
        Hit indices into ``points``, in chain order.
    max_dist : float
        Largest allowed step length, in the units of ``points`` (e.g. cm).
    points : array-like, shape (N, dim) or (N,)
        Hit coordinates. Any array-like works (ndarray, list of tuples, ...).
        The step length is the plain Euclidean norm of the coordinate
        difference, so Cartesian (x, y, z) is the meaningful choice. With
        (r, phi, z) the angle is mixed into the length and phi is not wrapped
        at +-pi.

    Returns
    -------
    ok : bool
        True if every step is <= max_dist (trivially True for fewer than 2 hits).
    bad_jumps : list of (i_prev, i_cur, d)
        One entry per violating step, in chain order; ``d`` is the step length.
    """
    bad_jumps = []
    if len(chain) < 2:
        return True, bad_jumps  # trivial chain

    # np.asarray lets plain lists/tuples of coordinates work as well as arrays.
    pts = np.asarray(points, dtype=float)
    if pts.ndim == 1:
        pts = pts[:, None]  # 1-D coordinates: treat each value as a 1-D point

    ordered = pts[np.asarray(chain, dtype=int)]
    steps = np.linalg.norm(np.diff(ordered, axis=0), axis=1)

    # Step k connects chain[k] -> chain[k + 1].
    for k in np.flatnonzero(steps > max_dist):
        bad_jumps.append((chain[k], chain[k + 1], float(steps[k])))

    ok = (len(bad_jumps) == 0)
    return ok, bad_jumps


In [32]:
# Check the chain currently bound to `chain` for oversized steps in XYZ.
# NOTE(review): this cell relies on `chain` and `cluster_points_xyz` already existing
# from earlier cells. The pass-2 cells above use `cluster_points` for XYZ; confirm
# that it is the same array.
MAX_JUMP_CM = 3.0  # single source for the threshold used in the check and the message

ok, bad_jumps = check_chain_max_jump(chain, max_dist=MAX_JUMP_CM, points=cluster_points_xyz)
if not ok:
    print(f"Chain has jumps > {MAX_JUMP_CM:g} cm:")
    for i_prev, i_cur, d in bad_jumps:
        print(f"  {i_prev} -> {i_cur}: d = {d:.3f} cm")


In [33]:
def truncate_chain_at_jump(chain, max_dist, points):
    """
    Return the leading part of ``chain`` that ends before the first step
    longer than ``max_dist``.

    Parameters
    ----------
    chain : sequence of int
        Hit indices into ``points``, in chain order.
    max_dist : float
        Largest allowed step length, in the units of ``points``.
    points : array-like, shape (N, dim) or (N,)
        Hit coordinates (any array-like). Step length is the Euclidean norm of
        the coordinate difference, so Cartesian coordinates are the meaningful
        choice.

    Returns
    -------
    list
        Always a new list that never aliases the input. It holds the chain
        elements up to and including the last hit before the first oversized
        step. Chains with fewer than two hits are returned whole.
    """
    if len(chain) < 2:
        return list(chain)

    # np.asarray lets plain lists/tuples of coordinates work as well as arrays.
    pts = np.asarray(points, dtype=float)
    if pts.ndim == 1:
        pts = pts[:, None]

    ordered = pts[np.asarray(chain, dtype=int)]
    steps = np.linalg.norm(np.diff(ordered, axis=0), axis=1)

    # Step k connects chain[k] -> chain[k + 1]; keep chain[:k + 1] for the first bad k.
    too_long = np.flatnonzero(steps > max_dist)
    cut = int(too_long[0]) + 1 if too_long.size else len(chain)
    return list(chain[:cut])
