In [None]:
#%run input/Format.ipynb
import ROOT as root
from array import array
root.gErrorIgnoreLevel = root.kFatal
%jsroot on
import numpy as np

In [None]:
import os

# Directory holding the residual ROOT files. Defaults to the author's local
# path; set PATTERN_RECO_INPUT_DIR to run the notebook elsewhere without edits.
file_path = os.environ.get(
    "PATTERN_RECO_INPUT_DIR",
    "/Users/mitrankova/Jupyter/PatternRecognition/input/",
)
# Guarantee a trailing separator: file names are appended by plain string concatenation.
file_path = os.path.join(file_path, "")

# Input files to process (alternatives kept commented for quick switching).
#file_names=["0version_Box_TPC_Au_Au_ZeroField_1mrad_aligned_10evt_9nSkip_75570-0_resid.root"]
file_names = ['Au_Au_seeds_37thevt_66522-0_resid.root']
#file_names=['Au_Au_seeds__8evts_7skip_75405-0_resid.root']

In [None]:
# Global mode: use radial triplets everywhere
TRIPLET_MODE = "radial"


def assert_radial():
    """Force the module-level TRIPLET_MODE back to "radial", printing a warning
    if it had been changed to anything else."""
    global TRIPLET_MODE
    if TRIPLET_MODE == "radial":
        return
    print("Warning: Forcing TRIPLET_MODE='radial'")
    TRIPLET_MODE = "radial"


print("TRIPLET_MODE:", TRIPLET_MODE)


In [None]:
hists_read, hists_sim = [], []

# Every opened TFile is kept referenced in `open_files`: the TTrees returned by
# TFile.Get live inside their file, so the file must stay alive while they are used.
open_files = []
trees = {"cluster": [], "residual": [], "hit": [], "vertex": []}

# kind -> tree name inside each ROOT file (order matters only for Get call order)
_tree_names_in_file = {
    "cluster": "clustertree",
    "residual": "residualtree",
    "hit": "hittree",
    "vertex": "vertextree",
}

for fname in file_names:
    fpath = file_path + fname
    tfile = root.TFile.Open(fpath, "READ")
    if tfile and not tfile.IsZombie():
        open_files.append(tfile)
        for kind, tname in _tree_names_in_file.items():
            trees[kind].append(tfile.Get(tname))
    else:
        # keep per-file lists aligned with file_names
        print(f"Failed to open {fpath}")
        for kind in trees:
            trees[kind].append(None)

# Shorthand handles to the first file's trees (None when unavailable)
cluster_tree, residual_tree, hit_tree, vertex_tree = (
    trees[kind][0] if trees[kind] else None
    for kind in ("cluster", "residual", "hit", "vertex")
)

print(f"Loaded files: {len(open_files)}")
print("hit_tree entries:", hit_tree.GetEntries() if hit_tree else 0)
print("cluster_tree entries:", cluster_tree.GetEntries() if cluster_tree else 0)
print("residual_tree entries:", residual_tree.GetEntries() if residual_tree else 0)

In [None]:
# Quick accessibility check: collect the leading branch names of each tree
# (first 100 of hit_tree, first 10 of cluster_tree / residual_tree).
# Display e.g. `hit_branches` in a cell to inspect them.
def _leading_branch_names(tree, limit):
    """Names of the first `limit` branches of a ROOT TTree."""
    branch_list = tree.GetListOfBranches()
    count = min(limit, branch_list.GetEntries())
    return [branch_list.At(i).GetName() for i in range(count)]


if hit_tree:
    hit_branches = _leading_branch_names(hit_tree, 100)
if cluster_tree:
    cluster_branches = _leading_branch_names(cluster_tree, 10)
if residual_tree:
    residual_branches = _leading_branch_names(residual_tree, 10)

In [None]:
# Width of the pad axis (in pads) per sector for each module imod = 0, 1, 2,
# as used by get_hist to size the pad axis.
Npads = [
    94,   # imod 0
    128,  # imod 1
    192,  # imod 2
]

# ADC threshold sets, chosen by Select_ADC_treshold (index 0..3).
# The hit selections below only apply ADC_treshold_down.
# NOTE(review): ADC_treshold_up[2] (1000000) > ADC_treshold_up[3] (100000) - check for a typo.
ADC_treshold_up = [20, 100, 1000000, 100000]
ADC_treshold_down = [0, 20, 60, 200]
Select_ADC_treshold = 2  # 0, 1, 2, 3

In [None]:
hists = {}


def get_hist(sector, imod, side):
    """Return the cached (timebin, pad, layer) ADC histogram for one
    (sector, module, side), creating it on first use.

    Axes: 300 time bins over [0, 300); Npads[imod] pad bins over
    [Npads[imod]*sector, Npads[imod]*(sector+1)); 16 layer bins over
    [7 + 16*imod, 7 + 16*(imod+1)).
    """
    key = (sector, imod, side)
    if key in hists:
        return hists[key]

    hist = root.TH3F(
        f"hist3d_hard_sec{sector}_m{imod}_s{side}",
        f"3D ADC sec {sector} mod {imod} side {side}; timebin; pad; layer",
        300, 0, 300,
        Npads[imod], Npads[imod] * sector, Npads[imod] * (sector + 1),
        16, 7 + 16 * imod, 7 + 16 * (imod + 1)
    )
    hists[key] = hist
    return hist


if hit_tree:
    adc_threshold = ADC_treshold_down[Select_ADC_treshold]
    n_entries = int(hit_tree.GetEntries())
    for entry in range(n_entries):
        hit_tree.GetEntry(entry)
        sector = int(hit_tree.sector)
        layer  = int(hit_tree.layer)
        imod   = (layer - 7) // 16
        if imod < 0 or imod > 2:
            continue

        side = int(hit_tree.side)
        if side not in (0, 1):
            continue

        # Inclusive cut (adc >= threshold), matching the `adc < threshold -> skip`
        # cut used when collecting hits for the KD-trees, so both views agree.
        if hit_tree.adc >= adc_threshold:
            get_hist(sector, imod, side).Fill(hit_tree.tbin, hit_tree.pad, layer, hit_tree.adc)


In [None]:
# Global-coordinate ADC map (gx, gy, gz) of all hits. Hits with transverse radius
# sqrt(gx^2 + gy^2) <= R_MIN_XYZ are skipped (inner-radius cut).
# NOTE(review): the z axis spans only [0, 150); hits with gz < 0 fall into the
# underflow bin - confirm this is intended.
R_MIN_XYZ = 12
hist3d_xyz = root.TH3F("hist3d_xyz", "3D ADC; x; y;z", 240, -60, 60, 240, -60, 60, 150, 0, 150)
if hit_tree:
    for entry in range(hit_tree.GetEntries()):
        hit_tree.GetEntry(entry)
        x_hit = hit_tree.gx
        y_hit = hit_tree.gy
        z_hit = hit_tree.gz
        adc_hit = hit_tree.adc
        # keep only hits outside the inner radius R_MIN_XYZ
        if (x_hit**2 + y_hit**2)**0.5 > R_MIN_XYZ:
            hist3d_xyz.Fill(x_hit, y_hit, z_hit, adc_hit)


In [None]:
# Optional drawing of hist3d_xyz (240 x 240 x 150 bins). Disabled by default;
# set DRAW_HIST3D_XYZ = True to render it.
DRAW_HIST3D_XYZ = False
if DRAW_HIST3D_XYZ:
    c1 = root.TCanvas("c1", "3D hits", 1200, 1000)
    hist3d_xyz.Draw("colz")
    c1.Draw()

# Build KD-trees of hits

In [None]:
import numpy as np
from collections import defaultdict
from scipy.spatial import cKDTree

# Collect hits per (sector, imod, side) key, with the same layer/side cuts as the
# histogram cell and the lower ADC threshold:
#   points[key]  -> list of (tbin, pad, layer)
#   payload[key] -> list of (adc, entry), parallel to points[key]
points = defaultdict(list)
payload = defaultdict(list)

# Shared set of claimed hit indices; created once here, outside the hit loop.
used_global = set()

# Debug aid: print low-timebin hits of layer 7 in key (0, 0, 0) while reading.
DEBUG_PRINT_HITS = True

if hit_tree:
    adc_threshold = ADC_treshold_down[Select_ADC_treshold]
    n_entries = int(hit_tree.GetEntries())
    for entry in range(n_entries):
        hit_tree.GetEntry(entry)

        sector = int(hit_tree.sector)
        layer  = int(hit_tree.layer)
        imod   = (layer - 7) // 16
        if imod < 0 or imod > 2:
            continue

        side = int(hit_tree.side)
        if side not in (0, 1):
            continue

        adc = int(hit_tree.adc)
        if adc < adc_threshold:
            continue

        tbin = int(hit_tree.tbin)
        pad  = int(hit_tree.pad)

        key = (sector, imod, side)
        points[key].append((tbin, pad, layer))
        payload[key].append((adc, entry))

        if DEBUG_PRINT_HITS and key == (0, 0, 0) and layer == 7 and tbin < 30:
            print(f"Entry {entry}: tbin={tbin}, pad={pad}, layer={layer}, adc={adc}")

# One KD-tree per key over its (tbin, pad, layer) points
kdtree = {}
pts_arr = {}
for key, pts in points.items():
    arr = np.asarray(pts, dtype=np.float32)
    pts_arr[key] = arr
    kdtree[key] = cKDTree(arr)

print("Built KD-trees:", len(kdtree))


In [None]:
def box_query(key, tmin, tmax, pmin, pmax, lmin, lmax, trees=None, arrays=None):
    """Return indices of the hits of `key` inside the closed box
    [tmin, tmax] x [pmin, pmax] x [lmin, lmax] in (tbin, pad, layer).

    Parameters
    ----------
    key : tuple
        (sector, imod, side) key into the KD-tree / point dictionaries.
    tmin, tmax, pmin, pmax, lmin, lmax : numbers
        Inclusive box bounds.
    trees, arrays : dict, optional
        key -> cKDTree and key -> Nx3 point array. Default to the notebook
        globals `kdtree` and `pts_arr`.

    Returns
    -------
    list of int
        Indices into pts_arr[key] / payload[key] (KD-tree query order);
        empty if `key` has no hits.
    """
    if trees is None:
        trees = kdtree
    if arrays is None:
        arrays = pts_arr
    if key not in trees:
        return []
    tree = trees[key]
    arr = arrays[key]

    # Preselect with the ball circumscribing the box. Computed in float64 and
    # padded slightly: a float32 norm can round *below* the true circumradius
    # (e.g. sqrt(125)), which silently drops hits sitting on the box corners.
    center = np.array([(tmin + tmax) / 2.0, (pmin + pmax) / 2.0, (lmin + lmax) / 2.0], dtype=np.float64)
    half = np.array([(tmax - tmin) / 2.0, (pmax - pmin) / 2.0, (lmax - lmin) / 2.0], dtype=np.float64)
    radius = float(np.linalg.norm(half)) + 1e-6

    cand = tree.query_ball_point(center, r=radius)

    # exact box filter
    out = []
    for i in cand:
        t, p, l = arr[i]
        if (tmin <= t <= tmax) and (pmin <= p <= pmax) and (lmin <= l <= lmax):
            out.append(i)
    return out


In [None]:
def knn(key, t, p, l, k=10):
    """Return the cKDTree.query result (distances, indices) for the k nearest
    hits of `key` to the point (tbin=t, pad=p, layer=l)."""
    tree = kdtree[key]
    return tree.query([t, p, l], k=k)


In [None]:
# Example: list the hits of key (0, 0, 0) inside a small (tbin, pad, layer) box.
key = (0, 0, 0)
idxs = box_query(key, tmin=10, tmax=30, pmin=50, pmax=60, lmin=7, lmax=7)

print(f"Box query found {len(idxs)} hits in key {key}:")

hits_in_box = [(pts_arr[key][i], payload[key][i]) for i in idxs]
for (tbin, pad, layer), (adc, entry) in hits_in_box:
    print(f"  Entry {entry}: tbin={int(tbin)}, pad={int(pad)}, layer={int(layer)}, adc={adc}")


In [None]:
import numpy as np
from collections import defaultdict
from scipy.spatial import cKDTree

def build_layer_indices_and_trees(key, pts_arr, payload):
    """Split the hits of one key by layer and build a 2D KD-tree per layer.

    Input:
      pts_arr[key]  : Nx3 array of (tbin, pad, layer)
      payload[key]  : length-N list of (adc, entry)
    Output:
      {layer: {"idx":  global indices (into pts_arr[key]) of this layer's hits,
               "tp":   Mx2 int32 array of (tbin, pad),
               "adc":  M int32 array,
               "tree": cKDTree over tp}}
    Layers appear in order of their first hit.
    """
    coords = pts_arr[key]
    adc_all = np.array([hit[0] for hit in payload[key]], dtype=np.int32)

    members = defaultdict(list)
    for gi, row in enumerate(coords):
        members[int(row[2])].append(gi)

    layers = {}
    for layer, gis in members.items():
        gi_arr = np.asarray(gis, dtype=np.int32)
        tp_int = coords[gi_arr][:, :2].astype(np.int32)  # (tbin, pad) as integers
        layers[layer] = {
            "idx": gi_arr,
            "tp": tp_int,
            "adc": adc_all[gi_arr],
            "tree": cKDTree(tp_int.astype(np.float32)),
        }
    return layers


# Find local ADC maxima (seeds)

In [None]:
def is_local_max(layer_data, j):
    """Return True if hit `j` is a strict local ADC maximum within its layer.

    The neighbourhood is the 8-connected (timebin, pad) window |dt| <= 1,
    |dp| <= 1 around the hit, excluding the hit itself. Any neighbour with
    ADC >= this hit's ADC disqualifies it, so two equal neighbouring maxima
    both return False.
    """
    tp = layer_data["tp"]
    adc = layer_data["adc"]
    centre_adc = int(adc[j])
    t_c, p_c = tp[j]

    # a ball of radius sqrt(2) contains the whole 3x3 window
    nearby = layer_data["tree"].query_ball_point([t_c, p_c], r=np.sqrt(2.0))

    return not any(
        k != j
        and abs(tp[k][0] - t_c) <= 1
        and abs(tp[k][1] - p_c) <= 1
        and adc[k] >= centre_adc
        for k in nearby
    )



def find_seeds(layers):
    """Return every strict local ADC maximum as a (layer, j_in_layer) tuple,
    walking layers in `layers` iteration order and hits in index order."""
    return [
        (layer, j)
        for layer, ld in layers.items()
        for j in range(ld["tp"].shape[0])
        if is_local_max(ld, j)
    ]


In [None]:
def query_box_in_layer(layer_data, t0, p0, dt, dp):
    """Return local indices j of hits in this layer with
    |tbin - t0| <= dt and |pad - p0| <= dp.

    t0 and p0 are converted with int() (truncation); pass already-rounded
    values if the centre is fractional.
    """
    tp   = layer_data["tp"]   # int32
    tree = layer_data["tree"]

    t0 = int(t0); p0 = int(p0)
    # Ball circumscribing the box, padded by a tiny tolerance so floating-point
    # rounding of sqrt(dt^2 + dp^2) can never exclude hits on the exact corners.
    # The exact box test below keeps the result exact.
    r = float((dt * dt + dp * dp) ** 0.5) + 1e-6
    cand = tree.query_ball_point([float(t0), float(p0)], r=r)

    out = []
    for j in cand:
        t = int(tp[j, 0]); p = int(tp[j, 1])
        if abs(t - t0) <= dt and abs(p - p0) <= dp:
            out.append(j)
    return out

# Chaining

In [None]:

def pick_best_candidate(layer_data, cand_js, t0, p0):
    """Pick the candidate with the highest ADC; ties go to the smallest squared
    (tbin, pad) distance from (int(t0), int(p0)), then to the earliest candidate.
    Returns None when cand_js is empty."""
    tp  = layer_data["tp"]   # int32
    adc = layer_data["adc"]
    t_ref, p_ref = int(t0), int(p0)

    def rank(j):
        d_t = int(tp[j, 0]) - t_ref
        d_p = int(tp[j, 1]) - p_ref
        return (int(adc[j]), -(d_t * d_t + d_p * d_p))

    # max() keeps the first of equally ranked items, like a strict '>' scan
    return max(cand_js, key=rank, default=None)


In [None]:
def find_seeds_sorted(layers):
    """Return (layer, j_in_layer, adc) for every strict local maximum,
    highest ADC first. Equal-ADC seeds keep their layer-iteration order
    (the sort is stable)."""
    seeds = [
        (layer, j, int(ld["adc"][j]))
        for layer, ld in layers.items()
        for j in range(ld["tp"].shape[0])
        if is_local_max(ld, j)
    ]
    return sorted(seeds, key=lambda seed: -seed[2])


# Find horizontal chains

In [None]:
import numpy as np
from collections import defaultdict, deque
from scipy.spatial import cKDTree

def _order_chain_in_layer(layer_data, chain_seed_js):
    """
    Order a set of seed indices (j in this layer) into a 'nice' polyline order.
    Uses PCA-like 1D projection when possible; falls back to sorting by (t,p).
    """
    tp = layer_data["tp"]  # int32, shape (M,2)
    pts = np.asarray([tp[j] for j in chain_seed_js], dtype=np.float32)  # Nx2

    if len(chain_seed_js) <= 2:
        # stable ordering
        order = np.lexsort((pts[:, 1], pts[:, 0]))  # sort by t then pad
        return [chain_seed_js[i] for i in order]

    # PCA direction = first right-singular vector of centered coords
    c = pts.mean(axis=0, keepdims=True)
    X = pts - c
    # SVD on 2D: X = U S Vt
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    direction = vt[0]  # 2-vector

    s = X @ direction  # projection coordinate
    order = np.argsort(s)
    return [chain_seed_js[i] for i in order]

In [None]:
def chain_avg_dp_dt_between_maxima(chain, layers, seed_set):
    """
    Mean absolute pad step and timebin step between consecutive maxima
    (nodes of `chain` that are in seed_set), in chain order.

    Returns (avg_dp, avg_dt, n_maxima); the averages are None when fewer
    than two maxima are present.
    """
    maxima = [node for node in chain if node in seed_set]
    n_max = len(maxima)
    if n_max < 2:
        return None, None, n_max

    coords = [layers[int(ly)]["tp"][int(j)] for ly, j in maxima]
    steps = list(zip(coords, coords[1:]))
    pad_steps = [abs(int(b[1]) - int(a[1])) for a, b in steps]
    time_steps = [abs(int(b[0]) - int(a[0])) for a, b in steps]
    return float(np.mean(pad_steps)), float(np.mean(time_steps)), n_max


In [None]:
from collections import deque

def build_horizontal_cluster_allhits_from_seed(layers, seed_layer, seed_j, used_global,
                                               dt=2, dp=1):
    """
    Breadth-first flood fill inside ONE layer, starting at a seed hit and
    following neighbours with |dtbin| <= dt and |dpad| <= dp.

    Every reached hit not already in used_global is claimed (added to
    used_global) as it is visited. Returns (chain, consumed_gis):
      chain        : [(layer, j), ...] in visiting order
      consumed_gis : global indices claimed here, so the caller can undo
                     the claims if the cluster is rejected.
    Returns ([], []) if the layer is unknown or the seed is already used.
    """
    layer = int(seed_layer)
    ld = layers.get(layer)
    if ld is None:
        return [], []

    start = int(seed_j)
    gidx = ld["idx"]
    if int(gidx[start]) in used_global:
        return [], []

    frontier = deque([start])
    seen = {start}
    members = []
    claimed = []

    while frontier:
        j = int(frontier.popleft())
        gi = int(gidx[j])
        if gi in used_global:
            # already taken by an earlier accepted chain
            continue

        used_global.add(gi)
        claimed.append(gi)
        members.append(j)

        t_here = int(ld["tp"][j, 0])
        p_here = int(ld["tp"][j, 1])
        for nj in query_box_in_layer(ld, t_here, p_here, dt=dt, dp=dp):
            nj = int(nj)
            if nj not in seen:
                seen.add(nj)
                frontier.append(nj)

    return [(layer, j) for j in members], claimed


In [None]:
def find_horizontal_chains_allhits(layers, seeds_sorted, used_global,
                                   dt=2, dp=1,
                                   min_hits=3,
                                   min_pad_span=5,      # accept only if (pmax - pmin) > min_pad_span
                                   order_for_drawing=True):
    """
    Build same-layer (horizontal) clusters from seeds in the given order.

    seeds_sorted: list of (layer, j, adc), typically sorted by ADC descending.
    A cluster is accepted only if it has >= min_hits hits AND its pad span
    (pmax - pmin) is > min_pad_span. Rejected clusters give their hits back
    (removed from used_global) so later seeds / vertical chaining can use them.
    Accepted clusters are optionally re-ordered along their main direction.

    Returns a list of chains, each a list of (layer, j).
    """
    accepted = []

    for seed_layer, seed_j, _adc in seeds_sorted:
        chain, claimed = build_horizontal_cluster_allhits_from_seed(
            layers, seed_layer, seed_j, used_global, dt=dt, dp=dp
        )
        if not chain:
            continue

        ly = int(chain[0][0])
        ld = layers[ly]
        js = [jj for _, jj in chain]

        pads = [int(ld["tp"][jj, 1]) for jj in js]
        pad_span = max(pads) - min(pads) if pads else 0

        too_small = len(js) < min_hits or pad_span <= min_pad_span
        if too_small:
            used_global.difference_update(int(gi) for gi in claimed)
            continue

        if order_for_drawing:
            chain = [(ly, int(jj)) for jj in _order_chain_in_layer(ld, js)]

        accepted.append(chain)

    return accepted


# Build vertical chains

In [None]:
def _wls_line(x, y, w):
    """
    Weighted least squares for y = a + b*x
    Returns (a, b).
    """
    sw = sx = sy = sxx = sxy = 0.0
    for xi, yi, wi in zip(x, y, w):
        wi = float(wi)
        if wi <= 0:
            continue
        sw  += wi
        sx  += wi * xi
        sy  += wi * yi
        sxx += wi * xi * xi
        sxy += wi * xi * yi

    if sw <= 0:
        return 0.0, 0.0

    denom = (sw * sxx - sx * sx)
    if denom == 0:
        a = sy / sw
        b = 0.0
        return a, b

    b = (sw * sxy - sx * sy) / denom
    a = (sy - b * sx) / sw
    return a, b


def _fit_from_pts(pts):
    """Fit t(L) and p(L) as weighted straight lines in the layer number L.

    pts = [(L, t, p, w), ...] -> (at, bt, ap, bp) with
    t = at + bt*L and p = ap + bp*L (weights w).
    """
    layer_x, t_y, p_y, weights = [], [], [], []
    for L, t, p, w in pts:
        layer_x.append(float(L))
        t_y.append(float(t))
        p_y.append(float(p))
        weights.append(float(w))

    at, bt = _wls_line(layer_x, t_y, weights)
    ap, bp = _wls_line(layer_x, p_y, weights)
    return at, bt, ap, bp


def _build_anchor_index(layers, chains_main_list, chain_id_list=None):

    if chain_id_list is None:
        chain_id_list = list(range(len(chains_main_list)))

    anchors_by_layer = {}
    for cid in chain_id_list:
        chain = chains_main_list[cid]
        for L, j in chain:
            ld = layers.get(L)
            if ld is None:
                continue
            t = int(ld["tp"][j, 0])
            p = int(ld["tp"][j, 1])
            gi = int(ld["idx"][j])
            anchors_by_layer.setdefault(int(L), []).append(
                {"t": t, "p": p, "gi": gi, "chain_id": cid, "j": int(j)}
            )
    return anchors_by_layer



In [None]:

def build_chain_from_seed_unique_dir(
    layers, seed_layer, seed_j, used_global, seed_set,
    dt_win0=3, dp_win0=2,
    layer_step=1,
    dt_gate=1, dp_gate=1,
    allow_skip_one_layer=True,
    anchors_by_layer=None,
    allow_merge_to_anchors=True,
    allow_merge_same_layer=True,
):
    """
    Grow a chain of hits across layers in one direction (layer_step, e.g. +1
    or -1) starting from a seed hit.

    Step 1: collect all unused hits of the next layer inside a
    (dt_win0, dp_win0) window around the seed; if empty and
    allow_skip_one_layer, try one layer further.
    Later steps: fit t(L) and p(L) as ADC-weighted straight lines through all
    accepted hits, predict the next layer, and collect unused hits within
    (dt_gate, dp_gate) of the prediction (again with an optional one-layer skip).

    - Hits in seed_set go to chain_main, other hits to chain_support; all of
      them enter the fit and are added to used_global.
    - When a step finds no hits, the chain stops. If allow_merge_to_anchors
      and anchors_by_layer are given, anchor hits near the expected position
      are searched (step 1 also checks the seed's own layer first when
      allow_merge_same_layer); the chain id with the most anchor hits in the
      window is returned as merge_chain_id.

    Returns: chain_main, chain_support, fit_history, merge_chain_id
    """
    chain_main = []
    chain_support = []
    fit_history = []
    merge_chain_id = None

    cur_layer = int(seed_layer)
    cur_j = int(seed_j)

    ld0 = layers.get(cur_layer)
    if ld0 is None:
        return chain_main, chain_support, fit_history, merge_chain_id

    gi0 = int(ld0["idx"][cur_j])
    if gi0 in used_global:
        return chain_main, chain_support, fit_history, merge_chain_id

    # Seed hit: always in chain_main and in the fit
    t0 = int(ld0["tp"][cur_j, 0])
    p0 = int(ld0["tp"][cur_j, 1])
    w0 = float(ld0["adc"][cur_j])
    chain_main.append((cur_layer, cur_j))
    used_global.add(gi0)

    # Fit input points (L, t, p, w) of every accepted hit
    pts = [(cur_layer, t0, p0, w0)]

    at = bt = ap = bp = None

    def _collect_hits_in_box(layer_num, t_center, p_center, dt, dp):
        """Return (seed_js, support_js): unused hits of `layer_num` inside the
        window, split by membership in seed_set."""
        ld = layers.get(layer_num)
        if ld is None:
            return [], []

        # Round the (possibly fractional) predicted centre to the nearest bin;
        # plain int() truncates and shifts the window towards lower t / pad.
        js = query_box_in_layer(ld, int(round(t_center)), int(round(p_center)), int(dt), int(dp))
        seed_js = []
        support_js = []
        for j in js:
            gi = int(ld["idx"][j])
            if gi in used_global:
                continue
            if (int(layer_num), int(j)) in seed_set:
                seed_js.append(int(j))
            else:
                support_js.append(int(j))
        return seed_js, support_js

    def _find_anchor_merge(layer_num, t_center, p_center, dt, dp):
        """Return the anchor chain id with the most anchor hits inside the
        window on `layer_num`, or None."""
        if not allow_merge_to_anchors or anchors_by_layer is None:
            return None

        lst = anchors_by_layer.get(int(layer_num), [])
        if not lst:
            return None

        counts = {}
        for a in lst:
            if abs(a["t"] - t_center) <= dt and abs(a["p"] - p_center) <= dp:
                counts[a["chain_id"]] = counts.get(a["chain_id"], 0) + 1

        if not counts:
            return None

        return max(counts.items(), key=lambda x: x[1])[0]

    def _try_anchor_merge(layers_to_try, t_center, p_center, dt, dp):
        """Try anchor-merge in the given layer order; returns chain_id or None."""
        for L in layers_to_try:
            cid = _find_anchor_merge(int(L), t_center, p_center, dt, dp)
            if cid is not None:
                return cid
        return None

    def _accept_group(layer_num, seed_js, support_js):
        """Add the hits to chain_main / chain_support and the fit points,
        marking them as used."""
        ld = layers[layer_num]

        for j in seed_js:
            gi = int(ld["idx"][j])
            if gi in used_global:
                continue
            t = int(ld["tp"][j, 0])
            p = int(ld["tp"][j, 1])
            w = float(ld["adc"][j])
            chain_main.append((layer_num, j))
            used_global.add(gi)
            pts.append((layer_num, t, p, w))

        # support hits: used for fitting only
        for j in support_js:
            gi = int(ld["idx"][j])
            if gi in used_global:
                continue
            t = int(ld["tp"][j, 0])
            p = int(ld["tp"][j, 1])
            w = float(ld["adc"][j])
            chain_support.append((layer_num, j))
            used_global.add(gi)
            pts.append((layer_num, t, p, w))

    def _refit_and_store():
        """Refit t(L), p(L) on all points and record the fit in fit_history."""
        nonlocal at, bt, ap, bp
        at, bt, ap, bp = _fit_from_pts(pts)
        fit_history.append({
            "n_hits": len(pts),
            "n_seeds": len(chain_main),
            "n_support": len(chain_support),
            "t_fit": {"a": at, "b": bt},
            "p_fit": {"a": ap, "b": bp},
        })

    # ---- Step 1: first group in the next layer (window around the seed) ----
    next_layer = cur_layer + layer_step
    seed_js, support_js = _collect_hits_in_box(next_layer, t0, p0, dt_win0, dp_win0)

    if not seed_js and not support_js and allow_skip_one_layer:
        next2 = next_layer + layer_step
        seed_js, support_js = _collect_hits_in_box(next2, t0, p0, dt_win0, dp_win0)
        if seed_js or support_js:
            next_layer = next2

    # Nothing found: look for an anchor chain to merge into, then stop
    if not seed_js and not support_js:
        if allow_merge_to_anchors:
            layers_to_try = [next_layer]
            if allow_skip_one_layer:
                layers_to_try.append(next_layer + layer_step)
            if allow_merge_same_layer:
                layers_to_try.insert(0, cur_layer)   # same layer is checked first

            cid = _try_anchor_merge(layers_to_try, t0, p0, dt_win0, dp_win0)
            if cid is not None:
                merge_chain_id = cid
        return chain_main, chain_support, fit_history, merge_chain_id

    _accept_group(next_layer, seed_js, support_js)
    _refit_and_store()
    cur_layer = next_layer

    # ---- Subsequent steps: fit prediction with gating ----
    while True:
        target = cur_layer + layer_step
        if target not in layers:
            break

        t_pred = at + bt * float(target)
        p_pred = ap + bp * float(target)
        seed_js, support_js = _collect_hits_in_box(target, t_pred, p_pred, dt_gate, dp_gate)

        if not seed_js and not support_js and allow_skip_one_layer:
            target2 = target + layer_step
            if target2 in layers:
                t_pred2 = at + bt * float(target2)
                p_pred2 = ap + bp * float(target2)
                seed_js, support_js = _collect_hits_in_box(target2, t_pred2, p_pred2, dt_gate, dp_gate)
                if seed_js or support_js:
                    target = target2
                    t_pred = t_pred2
                    p_pred = p_pred2

        # Nothing found: look for an anchor chain to merge into, then stop
        if not seed_js and not support_js:
            if allow_merge_to_anchors:
                cid = _find_anchor_merge(target, t_pred, p_pred, dt_gate, dp_gate)
                if cid is None and allow_skip_one_layer:
                    # extrapolate the fit to the skipped-to layer instead of
                    # reusing the prediction made for `target`
                    skip_layer = target + layer_step
                    cid = _find_anchor_merge(
                        skip_layer,
                        at + bt * float(skip_layer),
                        ap + bp * float(skip_layer),
                        dt_gate, dp_gate,
                    )
                if cid is not None:
                    merge_chain_id = cid
            break

        _accept_group(target, seed_js, support_js)
        _refit_and_store()
        cur_layer = target

    return chain_main, chain_support, fit_history, merge_chain_id

In [None]:
############### TOP TO BOTTOM ###############
def _release_chain_hits(layers, chain_main, chain_support, used):
    """Remove the hits of a rejected chain from `used` so they can be claimed again."""
    for layer_num, j in list(chain_main) + list(chain_support):
        used.discard(int(layers[int(layer_num)]["idx"][int(j)]))


def two_pass_chaining(
    layers, seeds_sorted, seed_set,
    dt_win0=3, dp_win0=2, dt_gate=1, dp_gate=1,
    allow_skip_one_layer=True
):
    """
    Two-pass vertical chaining.

    PASS 1: TOP -> BOTTOM (layer_step=-1, no merging). Chains with >= 2 main
            hits become anchor chains; the hits of rejected chains are
            released again.
    PASS 2: BOTTOM -> TOP (layer_step=+1, merging enabled). A chain that ends
            next to anchor hits is prepended to that anchor chain (only the
            first merge per anchor). Other chains are kept standalone if they
            have >= 2 main hits. Rejected chains release their hits.

    Returns: list of (chain_main, chain_support) tuples - the pass-1 chains
    that were not merged, followed by all pass-2 chains.
    """
    used = set()

    # ---- PASS 1: TOP → BOTTOM (no merging, creates anchors) ----
    print("="*60)
    print("PASS 1: TOP → BOTTOM (creating anchor chains)")
    print("="*60)

    seeds_topdown = sorted(seeds_sorted, key=lambda x: -int(x[0]))
    chains_pass1 = []

    for ly, j, _adc in seeds_topdown:
        chain_m, chain_s, _hist, _merge_id = build_chain_from_seed_unique_dir(
            layers, int(ly), int(j), used, seed_set,
            dt_win0=dt_win0, dp_win0=dp_win0,
            layer_step=-1,  # Going DOWN
            dt_gate=dt_gate, dp_gate=dp_gate,
            allow_skip_one_layer=allow_skip_one_layer,
            anchors_by_layer=None,  # No anchors in pass 1
            allow_merge_to_anchors=False
        )
        if len(chain_m) >= 2:
            chains_pass1.append((chain_m, chain_s))
        else:
            # The chain builder marks the seed (and any support hits) as used
            # even when the chain is too short. Release them: otherwise every
            # seed is in `used` after pass 1 and pass 2 can build nothing.
            _release_chain_hits(layers, chain_m, chain_s, used)

    print(f"\nPass 1 complete: {len(chains_pass1)} chains created")

    # Anchor index from the pass-1 main chains
    anchors_by_layer = _build_anchor_index(layers, [c[0] for c in chains_pass1])

    # ---- PASS 2: BOTTOM → TOP (with merging) ----
    print("\n" + "="*60)
    print("PASS 2: BOTTOM → TOP (merging into anchors)")
    print("="*60)

    seeds_bottomup = sorted(seeds_sorted, key=lambda x: int(x[0]))
    merged_chain_ids = set()
    chains_pass2 = []

    for ly, j, _adc in seeds_bottomup:
        chain_m, chain_s, _hist, merge_id = build_chain_from_seed_unique_dir(
            layers, int(ly), int(j), used, seed_set,
            dt_win0=dt_win0, dp_win0=dp_win0,
            layer_step=+1,  # Going UP
            dt_gate=dt_gate, dp_gate=dp_gate,
            allow_skip_one_layer=allow_skip_one_layer,
            anchors_by_layer=anchors_by_layer,
            allow_merge_to_anchors=True,
            allow_merge_same_layer=True,
        )

        is_new_merge = merge_id is not None and merge_id not in merged_chain_ids

        # A merge found before any further seed is accepted (e.g. the
        # same-layer check of the first step) comes back with only the seed in
        # chain_m. Keep such fragments when they attach to a not-yet-merged
        # anchor; drop (and release) every other chain with < 2 main hits.
        if len(chain_m) < 2 and not is_new_merge:
            _release_chain_hits(layers, chain_m, chain_s, used)
            continue

        if is_new_merge:
            merged_chain_ids.add(merge_id)
            # Combine: pass-2 chain + pass-1 anchor chain
            merged_main = list(chain_m) + list(chains_pass1[merge_id][0])
            merged_support = list(chain_s) + list(chains_pass1[merge_id][1])
            chains_pass2.append((merged_main, merged_support))
        else:
            # No merge, or the anchor was already merged: keep standalone
            chains_pass2.append((chain_m, chain_s))

    print(f"\nPass 2 complete: {len(chains_pass2)} chains created ({len(merged_chain_ids)} merged)")

    # ---- COMBINE RESULTS: unmerged pass-1 chains + all pass-2 chains ----
    final_chains = [
        (chain_m, chain_s)
        for i, (chain_m, chain_s) in enumerate(chains_pass1)
        if i not in merged_chain_ids
    ]
    final_chains.extend(chains_pass2)

    print(f"\nFinal result: {len(final_chains)} total chains")
    print(f"  {len(chains_pass1) - len(merged_chain_ids)} unmerged from pass 1")
    print(f"  {len(chains_pass2)} from pass 2 (including {len(merged_chain_ids)} merged)")
    print("="*60)

    return final_chains


def _build_anchor_index(layers, chain_mains):
    """
    Build index of anchor hits by layer.
    Returns: {layer_num: [{"t": t, "p": p, "chain_id": i}, ...]}
    """
    anchors = {}
    
    for chain_id, chain_main in enumerate(chain_mains):
        for layer_num, j in chain_main:
            ld = layers.get(int(layer_num))
            if ld is None:
                continue
            
            t = int(ld["tp"][j, 0])
            p = int(ld["tp"][j, 1])
            
            if layer_num not in anchors:
                anchors[layer_num] = []
            
            anchors[layer_num].append({
                "t": t,
                "p": p,
                "chain_id": chain_id
            })
    
    return anchors

In [None]:
def dump_chain(key, chain_main, chain_support, layers, pts_arr, payload, max_lines=200):
    """
    Print a human-readable listing of one chain's main and support hits.

    Each hit (layer, j) is resolved through layers[layer]["idx"][j] to a global
    index gi. The printed values come from pts_arr[key][gi] -> (tbin, pad, layer)
    and payload[key][gi] -> (adc, entry). At most `max_lines` main hits and
    `max_lines` support hits are printed.
    """
    print(f"\n=== CHAIN key={key} main_hits={len(chain_main)} support_hits={len(chain_support)} ===")

    for title, hits in (("Main hits:", chain_main), ("Support hits:", chain_support)):
        print(title)
        for layer, j in hits[:max_lines]:
            global_idx = int(layers[layer]["idx"][j])
            tbin, pad, layer0 = pts_arr[key][global_idx]
            adc, entry = payload[key][global_idx]
            print(f"  L={int(layer0):2d}  t={int(tbin):3d}  pad={int(pad):4d}  adc={int(adc):5d}  entry={entry}")


In [None]:
# Disabled debug snippet (kept inside a string literal, so it never executes).
# For key (0, 0, 0) it would list the seeds found by find_seeds_sorted that
# lie on layer 7, with their timebin, pad, ADC and entry number.
'''
key = (0, 0, 0)
layers = build_layer_indices_and_trees(key, pts_arr, payload)       
seeds_sorted = find_seeds_sorted(layers)
print(f"\nFound {len(seeds_sorted)} seeds in key {key}:")
for i, (layer, j,adc) in enumerate(seeds_sorted):

    #print(f"\nSeed {layer}: Layer={j[0]}  j_in_layer={j[1]}  adc={j[2]}")
    gi = int(layers[layer]["idx"][j])  # global index in pts_arr[key]
    tbin, pad, layer0 = pts_arr[key][gi]
    if layer==7:
        print(f"Seed {i:3d}: L={int(layer0):2d}  t={int(tbin):3d}  pad={int(pad):4d}  adc={int(layers[layer]['adc'][j]):5d}  entry={payload[key][gi][1]}")
'''

In [None]:
# Disabled debug snippet (string literal, never executed). It reruns the
# horizontal and two-pass vertical chaining for key (0, 0, 0) and dumps the
# 4th vertical chain with dump_chain.
# NOTE(review): it uses dp=1 for the horizontal pass, whereas the drawing
# cells use dp=3.
'''key = (0, 0, 0)
layers = build_layer_indices_and_trees(key, pts_arr, payload)

seeds_sorted = find_seeds_sorted(layers)   # (layer, j, adc)
print("Total number of ADC maximums:", len(seeds_sorted))

used_global = set()
seed_set = set((int(ly), int(j)) for (ly, j, adc) in seeds_sorted)
# ----------------------------------------------------------------------

# ---- 1st iteration: horizontal chains (same layer) INCLUDING ALL hits ----
horizontal_chains = find_horizontal_chains_allhits(
    layers, seeds_sorted, used_global,
    dt=2, dp=1,
    min_hits=3,
    min_pad_span=5,
    order_for_drawing=True
)

horizontal_chain_avg_dpdts = []
for chain in horizontal_chains:
    avg_dp, avg_dt, n_max = chain_avg_dp_dt_between_maxima(chain, layers, seed_set)
    horizontal_chain_avg_dpdts.append((avg_dp, avg_dt, n_max))

# ----------------------------------------------------------------------
# NEW: vertical chaining = TWO PASS (top->bottom) + (bottom->top w/ merge)
# IMPORTANT: reuse the SAME used_global that horizontal pass already filled.
# ----------------------------------------------------------------------

# Use ALL seeds; two_pass_chaining will skip those already used_global.
seed_set = {(ly, j) for ly, j, adc in seeds_sorted}

final_vertical_chains = two_pass_chaining(
    layers, seeds_sorted, seed_set,  # Add seed_set
    dt_win0=3, dp_win0=2,
    dt_gate=1, dp_gate=1,
    allow_skip_one_layer=True
)

for chain_main, chain_support in final_vertical_chains[3:4]:
    avg_dp, avg_dt, n_max = chain_avg_dp_dt_between_maxima(
        chain_main, layers, seed_set
    )
    dump_chain(key, chain_main, chain_support, layers, pts_arr, payload)'''

# Fit

In [None]:
import numpy as np
from scipy.optimize import curve_fit

def sagitta_fit_3d(chain, key, layers, pts_arr, payload):
    """
    Fit a chain of hits and measure its sagitta.

    The pad-layer plane gets an ADC-weighted parabola; the timebin-layer
    plane gets a straight line.

    Model:
        pad  = a*layer**2 + b*layer + c
        tbin = d*layer + e

    Both fits are weighted linear least squares with per-hit sigma = 1/sqrt(ADC),
    i.e. residual weights sqrt(ADC). They are solved in closed form with
    np.polyfit, which avoids an iterative optimizer for a linear problem.

    Args:
        chain: iterable of (layer, j_in_layer) hit references.
        key: key into pts_arr / payload.
        layers: per-layer data; layers[layer]["idx"][j] is the global hit index.
        pts_arr: pts_arr[key][gi] -> (tbin, pad, layer).
        payload: payload[key][gi] -> (adc, entry).

    Returns:
        (fit_params, fit_line), or (None, None) when the chain cannot be fit.
        That happens when there are fewer than 3 hits, fewer than 3 usable
        hits (positive ADC and finite coordinates), fewer than 3 distinct
        layers (a parabola is then under-determined), a layer span of 1 or
        less, or a numerical failure.

        fit_params keys:
            'pad_params'  : ndarray [a, b, c]
            'tbin_params' : ndarray [d, e]
            'chi2_pad', 'chi2_tbin': ADC-weighted mean squared residuals
                (not normalised by per-hit errors, so not a true chi2)
            'sagitta'     : |parabola - chord| at the middle of the layer span
            'n_hits'      : number of hits in `chain`
            'n_fit_hits'  : number of hits actually used in the fit
            'total_adc'   : summed ADC of the hits used in the fit
        fit_line: ROOT TPolyLine3D sampling the fit at 50 points in
            (tbin, pad, layer), for drawing.
    """
    if len(chain) < 3:
        return None, None

    # Gather (tbin, pad, layer) coordinates and ADC weights.
    points = []
    weights = []
    for (layer, j) in chain:
        gi = int(layers[layer]["idx"][j])
        tbin, pad, layer_val = pts_arr[key][gi]
        adc, _entry = payload[key][gi]
        points.append([float(tbin), float(pad), float(layer_val)])
        weights.append(float(adc))

    points = np.asarray(points, dtype=float).reshape(-1, 3)
    weights = np.asarray(weights, dtype=float)

    # Non-positive ADC would give an infinite/NaN sigma = 1/sqrt(adc).
    # Such hits (and any non-finite coordinates) carry no usable information,
    # so drop them instead of poisoning the fit.
    usable = (weights > 0) & np.all(np.isfinite(points), axis=1)
    points = points[usable]
    weights = weights[usable]
    if len(weights) < 3:
        return None, None

    tbins_arr = points[:, 0]
    pads_arr = points[:, 1]
    layers_arr = points[:, 2]

    # Check degeneracy BEFORE fitting. A parabola needs >= 3 distinct layers,
    # and the sagitta needs a layer span larger than 1.
    layer_min, layer_max = layers_arr.min(), layers_arr.max()
    if len(np.unique(layers_arr)) < 3 or layer_max - layer_min <= 1:
        return None, None

    # np.polyfit's `w` multiplies the residuals, so w = 1/sigma = sqrt(adc).
    fit_w = np.sqrt(weights)
    try:
        popt_pad = np.polyfit(layers_arr, pads_arr, 2, w=fit_w)    # [a, b, c]
        popt_tbin = np.polyfit(layers_arr, tbins_arr, 1, w=fit_w)  # [d, e]
    except (np.linalg.LinAlgError, ValueError) as e:
        print(f"Fit failed: {e}")
        return None, None

    # ADC-weighted mean squared residuals.
    sum_w = np.sum(weights)
    pad_residuals = pads_arr - np.polyval(popt_pad, layers_arr)
    tbin_residuals = tbins_arr - np.polyval(popt_tbin, layers_arr)
    chi2_pad = np.sum(weights * pad_residuals**2) / sum_w
    chi2_tbin = np.sum(weights * tbin_residuals**2) / sum_w

    # Sagitta: distance between the parabola and the chord joining its values
    # at the end layers, evaluated at the middle of the layer span.
    layer_mid = (layer_min + layer_max) / 2.0
    pad_start = np.polyval(popt_pad, layer_min)
    pad_end = np.polyval(popt_pad, layer_max)
    pad_straight = pad_start + (pad_end - pad_start) * (layer_mid - layer_min) / (layer_max - layer_min)
    sagitta = abs(np.polyval(popt_pad, layer_mid) - pad_straight)

    fit_params = {
        'pad_params': popt_pad,    # [a, b, c]
        'tbin_params': popt_tbin,  # [d, e]
        'chi2_pad': chi2_pad,
        'chi2_tbin': chi2_tbin,
        'sagitta': sagitta,
        'n_hits': len(chain),
        'n_fit_hits': int(len(weights)),
        'total_adc': np.sum(weights),
    }

    # Smooth polyline of the fit for drawing.
    layer_range = np.linspace(layer_min, layer_max, 50)
    pad_smooth = np.polyval(popt_pad, layer_range)
    tbin_smooth = np.polyval(popt_tbin, layer_range)

    fit_line = root.TPolyLine3D(len(layer_range))
    for i, (t, p, l) in enumerate(zip(tbin_smooth, pad_smooth, layer_range)):
        fit_line.SetPoint(i, float(t), float(p), float(l))

    return fit_params, fit_line


# Disabled helper (string literal, so `print_fit_quality` is NOT defined while
# quoted). It pretty-prints the dict returned by sagitta_fit_3d, or a failure
# message when that dict is None. Remove the quotes before calling it.
'''def print_fit_quality(fit_params, chain_type=""):
    """Print fit quality metrics"""
    if fit_params is None:
        print(f"{chain_type} Fit failed")
        return
    
    print(f"\n{chain_type} Fit Results:")
    print(f"  Sagitta: {fit_params['sagitta']:.4f}")
    print(f"  Chi2 (pad): {fit_params['chi2_pad']:.4f}")
    print(f"  Chi2 (tbin): {fit_params['chi2_tbin']:.4f}")
    print(f"  N hits: {fit_params['n_hits']}")
    print(f"  Total ADC: {fit_params['total_adc']:.1f}")
    print(f"  Pad params [a,b,c]: {fit_params['pad_params']}")
    print(f"  Tbin params [d,e]: {fit_params['tbin_params']}")
'''


# Drawing

In [None]:
def draw_chain_3d(chain_main, key, layers, pts_arr, color, width=2):
    """
    Draw a chain as a connected ROOT TPolyLine3D on the current pad.

    Args:
        chain_main: ordered list of (layer, j_in_layer) hit references.
        key: key into pts_arr.
        layers: layers[layer]["idx"][j] gives the global index into pts_arr[key].
        pts_arr: pts_arr[key][gi] -> (tbin, pad, layer).
        color, width: line colour and width.

    Returns:
        The drawn TPolyLine3D (the caller should keep a reference to it), or
        None when the chain has fewer than 2 hits.
    """
    if len(chain_main) < 2:
        return None

    polyline = root.TPolyLine3D(len(chain_main))
    polyline.SetLineColor(color)
    polyline.SetLineWidth(width)

    hits_of_key = pts_arr[key]
    for point_no, (layer, j) in enumerate(chain_main):
        tbin, pad, layer0 = hits_of_key[layers[layer]["idx"][j]]
        polyline.SetPoint(point_no, float(tbin), float(pad), float(layer0))

    polyline.Draw("same")
    return polyline

# Palette of ROOT colours used for chains; callers cycle through it with
# CHAIN_COLORS[i % len(CHAIN_COLORS)].
CHAIN_COLORS = [
    root.kRed + 1,
    root.kAzure + 2,
    root.kGreen + 2,
    root.kMagenta + 1,
    root.kOrange + 7,
    root.kCyan + 2,
    root.kViolet,
    root.kPink + 9,
]


In [None]:
def draw_local_maxima_3d(seeds, key, layers, pts_arr, color, mstyle=20, msize=1.2):
    """
    Draw seed hits (local ADC maxima) as a ROOT TPolyMarker3D on the current pad.

    Args:
        seeds: list of (layer, j_in_layer) hit references.
        key: key into pts_arr.
        layers: layers[layer]["idx"][j] gives the global index into pts_arr[key].
        pts_arr: pts_arr[key][gi] -> (tbin, pad, layer).
        color, mstyle, msize: marker colour, style and size.

    Returns:
        The drawn TPolyMarker3D, or None when `seeds` is empty.
    """
    count = len(seeds)
    if count == 0:
        return None

    markers = root.TPolyMarker3D(count)
    markers.SetMarkerColor(color)
    markers.SetMarkerStyle(mstyle)
    markers.SetMarkerSize(msize)

    for slot, (layer, j) in enumerate(seeds):
        tbin, pad, layer0 = pts_arr[key][int(layers[layer]["idx"][j])]
        # Offset by +0.5 in the layer coordinate, presumably so the markers do
        # not sit exactly on top of the drawn hits.
        markers.SetPoint(slot, float(tbin), float(pad), float(layer0) + 0.5)

    markers.Draw("same")   # overlay on the histogram already on the canvas
    return markers


In [None]:
# Disabled helper (string literal, so it is NOT defined while quoted).
# print_final_chains_layer_by_layer(final_chains, layers, ...) groups each
# (chain_main, chain_support) pair by layer, optionally restricted to
# [layer_min, layer_max], and prints per-hit global index, timebin, pad and ADC.
# max_chains limits how many chains are printed; layer_step_hint=-1 sorts the
# layers in descending order.
'''def print_final_chains_layer_by_layer(
    final_chains, layers,
    max_chains=None,
    sort_layers=True,
    layer_step_hint=None,
    layer_min=None,
    layer_max=None,
):
    def _hit_info(layer, j):
        ld = layers[int(layer)]
        gi = int(ld["idx"][int(j)])
        t  = int(ld["tp"][int(j), 0])
        p  = int(ld["tp"][int(j), 1])
        w  = float(ld["adc"][int(j)])
        return gi, t, p, w

    n = len(final_chains)
    nprint = n if max_chains is None else min(n, int(max_chains))

    print("\n" + "="*90)
    print(f"FINAL CHAINS: total={n}, printing={nprint}  (layer_min={layer_min}, layer_max={layer_max})")
    print("="*90)

    for ic in range(nprint):
        chain_main, chain_support = final_chains[ic]

        by_layer = {}
        for ly, j in chain_main:
            ly = int(ly); j = int(j)
            if layer_min is not None and ly < layer_min: continue
            if layer_max is not None and ly > layer_max: continue
            by_layer.setdefault(ly, {"main": [], "support": []})["main"].append(j)

        for ly, j in chain_support:
            ly = int(ly); j = int(j)
            if layer_min is not None and ly < layer_min: continue
            if layer_max is not None and ly > layer_max: continue
            by_layer.setdefault(ly, {"main": [], "support": []})["support"].append(j)

        if not by_layer:
            continue  # nothing in requested layer range

        layers_list = list(by_layer.keys())
        if sort_layers:
            if layer_step_hint in (+1, -1):
                layers_list.sort(reverse=(layer_step_hint == -1))
            else:
                layers_list.sort()

        n_main = sum(len(by_layer[ly]["main"]) for ly in layers_list)
        n_sup  = sum(len(by_layer[ly]["support"]) for ly in layers_list)

        print("\n" + "-"*90)
        print(f"Chain {ic}: n_main={n_main}  n_support={n_sup}  n_layers={len(layers_list)} (filtered)")
        print("-"*90)

        for ly in layers_list:
            mains = by_layer[ly]["main"]
            sups  = by_layer[ly]["support"]
            print(f"  Layer {ly:3d}: main={len(mains):2d}, support={len(sups):2d}")

            for j in sorted(mains):
                gi, t, p, w = _hit_info(ly, j)
                print(f"    M  j={j:4d}  gi={gi:7d}  t={t:5d}  p={p:4d}  adc={w:8.2f}")

            for j in sorted(sups):
                gi, t, p, w = _hit_info(ly, j)
                print(f"    S  j={j:4d}  gi={gi:7d}  t={t:5d}  p={p:4d}  adc={w:8.2f}")
'''

# Final Cell

In [None]:
# Draw one sector/side. For each of the 3 modules, show:
#   - the 3D ADC scatter,
#   - seed (local-maximum) markers,
#   - horizontal chains (thin lines),
#   - vertical two-pass chains with >= MIN_VERTICAL_MAIN_HITS main hits (thick lines).
draw_sector = 0
draw_side   = 0
MIN_VERTICAL_MAIN_HITS = 3   # shorter vertical chains are neither drawn nor summarised

canvases = []
drawn = {}
drawn_chains = {}
drawn_fits = {}         # store fit line references (if you re-enable fits)
chain_avg_dpdts = {}
chain_fit_params = {}   # store fit parameters / histories
drawn_seeds  = {}


for imod in range(3):
    c = root.TCanvas(
        f"c_sec{draw_sector}_s{draw_side}_mod{imod}",
        f"Sector {draw_sector} Side {draw_side} Module {imod}",
        1500, 1000
    )
    c.cd()

    key = (draw_sector, imod, draw_side)
    tag = f"[sec {draw_sector} side {draw_side} mod {imod}]"

    if key in hists:
        h3 = hists[key]
        h3.SetTitle(
            f"3D ADC; timebin; pad; layer "
            f"(sec {draw_sector}, side {draw_side}, mod {imod})"
        )
        h3.Draw("SCAT")   # drawn first: defines the 3D frame the overlays use
        drawn[imod] = h3

        # ---- build per-key structures ----
        layers = build_layer_indices_and_trees(key, pts_arr, payload)
        seeds_sorted = find_seeds_sorted(layers)  # (layer, j, adc)
        seeds_for_markers = [(ly, j) for (ly, j, adc) in seeds_sorted]
        seed_set = {(int(ly), int(j)) for (ly, j, adc) in seeds_sorted}
        used_global = set()

        # ---- draw local maxima markers ----
        drawn_seeds[imod] = draw_local_maxima_3d(
            seeds_for_markers, key, layers, pts_arr,
            color=root.kRed + 1,
            mstyle=20,
            msize=1.6
        )

        drawn_chains[imod] = []
        drawn_fits[imod] = []
        chain_fit_params[imod] = {'horizontal': [], 'vertical': []}

        # ============================================================
        # 1st iteration: HORIZONTAL (same-layer) CHAINS
        # ============================================================
        horizontal_chains = find_horizontal_chains_allhits(
            layers, seeds_sorted, used_global,
            dt=2, dp=3,
            min_hits=3,
            min_pad_span=5,
            order_for_drawing=True
        )

        horizontal_chain_avg_dpdts = []
        for ic, chain in enumerate(horizontal_chains):
            avg_dp, avg_dt, n_max = chain_avg_dp_dt_between_maxima(chain, layers, seed_set)
            horizontal_chain_avg_dpdts.append((avg_dp, avg_dt, n_max))

            color = CHAIN_COLORS[ic % len(CHAIN_COLORS)]
            pl = draw_chain_3d(chain, key, layers, pts_arr, color=color, width=2)
            if pl is not None:
                drawn_chains[imod].append(pl)

        # ============================================================
        # 2nd iteration: VERTICAL chains (TWO PASS)
        # ============================================================
        # NOTE(review): two_pass_chaining is not given `used_global`, so it
        # presumably does not exclude hits already claimed by horizontal
        # chains. TODO: pass used_global once two_pass_chaining supports it.
        final_vertical_chains = two_pass_chaining(
            layers, seeds_sorted, seed_set,
            dt_win0=3, dp_win0=2,
            dt_gate=1, dp_gate=1,
            allow_skip_one_layer=True
        )

        # One (avg_dp, avg_dt, n_max) entry per DRAWN vertical chain, in
        # drawing order (same order as the line colours).
        vertical_chain_avg_dpdts = []
        n_vertical_drawn = 0
        for chain_main, chain_support in final_vertical_chains:
            if len(chain_main) < MIN_VERTICAL_MAIN_HITS:
                continue

            avg_dp, avg_dt, n_max = chain_avg_dp_dt_between_maxima(
                chain_main, layers, seed_set
            )
            vertical_chain_avg_dpdts.append((avg_dp, avg_dt, n_max))

            color = CHAIN_COLORS[n_vertical_drawn % len(CHAIN_COLORS)]
            pl = draw_chain_3d(chain_main, key, layers, pts_arr, color=color, width=4)
            if pl is not None:
                drawn_chains[imod].append(pl)
            n_vertical_drawn += 1

            # To inspect a chain:
            # dump_chain(key, chain_main, chain_support, layers, pts_arr, payload)
            #
            # To overlay a sagitta fit (disabled):
            # fit_params, fit_line = sagitta_fit_3d(chain_main, key, layers, pts_arr, payload)
            # if fit_line:
            #     fit_line.SetLineColor(color); fit_line.SetLineWidth(6); fit_line.SetLineStyle(1)
            #     fit_line.Draw("same"); drawn_fits[imod].append(fit_line)
            #     chain_fit_params[imod]['vertical'].append(fit_params)

        chain_avg_dpdts[imod] = {
            "horizontal": horizontal_chain_avg_dpdts,
            "vertical": vertical_chain_avg_dpdts,
        }

        # Report per module only when this module was actually processed;
        # otherwise final_vertical_chains would be stale or undefined.
        print("----------------------------------------------------------")
        print(f"{tag} Vertical chains found:", len(final_vertical_chains))
        print(f"{tag} Vertical chains drawn (>= {MIN_VERTICAL_MAIN_HITS} main hits):", n_vertical_drawn)
        print("----------------------------------------------------------")
    else:
        print(f"{tag} No histogram for key {key}; nothing drawn")

    c.Update()
    canvases.append(c)

for c in canvases:
    c.Draw()

# Print summary statistics
# Disabled (string literal, never executed): prints, per module, how many
# horizontal / vertical fit results are stored in chain_fit_params.
'''print("\n" + "="*60)
print("="*60)
for imod in range(3):
    if imod in chain_fit_params:
        h_fits = chain_fit_params[imod]['horizontal']
        v_fits = chain_fit_params[imod]['vertical']

        print(f"\nModule {imod}:")
        print(f"  Horizontal chains (fits stored): {len(h_fits)}")
        print(f"  Vertical chains (fits stored):   {len(v_fits)}")'''


In [None]:
# Per-module geometry used to convert (pad, layer) into (phi, radius) for the
# phi-radius-timebin view.
# phi_bin_width[imod]: azimuthal width of one pad bin, presumably in radians
#   (pad * phi_bin_width is plotted on a "#phi (rad)" axis).
phi_bin_width = [0.0053073, 0.00530732, 0.00530731]
# module_radius[imod][k]: radius of the k-th layer of module imod, presumably
#   in cm (plotted on a "radius (cm)" axis). There are 16 entries per module,
#   looked up as module_radius[imod][(layer_num - 7) % 16].
module_radius = [
    [29.854978828112735, 31.869737083177956, 32.43665978627038, 33.00171100689825, 33.56863172731403, 34.133682357783, 34.70060474122243, 35.26565540941076, 35.83257683544541, 36.39762877363545, 36.964549975549694, 37.52960055896088, 38.09652180558749, 38.66157293473739, 39.228495272708216, 39.793545257944906],
    [41.65920253621078, 42.67990048015332, 43.7005755287188, 44.7212729094545, 45.7419615067264, 46.76264656230158, 47.78333428983602, 48.80401878201343, 49.82471910526506, 50.8454060012135, 51.866093793785126, 52.88677964073831, 53.90746625152035, 54.92815969895385, 55.948864895868056, 56.9695394315422],
    [58.910963349324035, 60.00800996331871, 61.10505851260341, 62.202104676954924, 63.29915863086735, 64.39619682986867, 65.49324606923312, 66.59029899562653, 67.68734047670296, 68.78439383353172, 69.88143340055497, 70.97848786511186, 72.07553264226554, 73.17257662017182, 74.2696338511705, 75.36667517343196],
]

In [None]:

# Store all chain hits with classifications and draw in phi-radius-timebin coordinates
# All sectors and modules combined per side

import ROOT as root
import numpy as np

# Storage for all chain hits
all_chain_hits = {}  # {(sector, imod, side): [chain_data]}

SECTOR_PHI_OFFSET = 2 * np.pi / 24   # phi offset added per sector index


def _chain_hit_record(layers, sector, imod, layer_num, j, hit_type):
    """
    Convert hit (layer_num, j) of one module into a hit dict.

    The dict holds the raw timebin, pad and ADC plus the derived phi and
    radius. Returns None when `layer_num` has no entry in `layers`.

    NOTE(review): phi = pad * phi_bin_width[imod] + sector * 2*pi/24 assumes
    that the 24 sector indices tile the full azimuth; confirm against the
    detector numbering. If sectors are numbered per side, the offset differs.
    The radius lookup module_radius[imod][(layer_num - 7) % 16] assumes each
    module's first layer is congruent to 7 mod 16 (TODO confirm).
    """
    ld = layers.get(int(layer_num))
    if ld is None:
        return None

    t = int(ld["tp"][j, 0])
    p = int(ld["tp"][j, 1])
    return {
        'layer': layer_num,
        'timebin': t,
        'pad': p,
        'phi': p * phi_bin_width[imod] + sector * SECTOR_PHI_OFFSET,
        'radius': module_radius[imod][(layer_num - 7) % 16],
        'adc': float(ld["adc"][j]),
        'hit_type': hit_type,
    }


# Process ALL sectors and modules
for sector in range(24):  # All 24 sectors
    for imod in range(3):
        for side in (0, 1):
            key = (sector, imod, side)
            if key not in hists:
                continue

            # Rebuild layers, seeds and chains for this module
            # (same parameters as the single-sector drawing cell).
            layers = build_layer_indices_and_trees(key, pts_arr, payload)
            seeds_sorted = find_seeds_sorted(layers)
            seed_set = {(ly, j) for ly, j, adc in seeds_sorted}
            used_global = set()

            horizontal_chains = find_horizontal_chains_allhits(
                layers, seeds_sorted, used_global,
                dt=2, dp=3, min_hits=3, min_pad_span=5, order_for_drawing=True
            )
            final_vertical_chains = two_pass_chaining(
                layers, seeds_sorted, seed_set,
                dt_win0=3, dp_win0=2, dt_gate=1, dp_gate=1,
                allow_skip_one_layer=True
            )

            module_chains = []
            all_chain_hits[key] = module_chains

            # Horizontal chains: a hit is 'main' if it is a seed, else 'support'.
            for ic, chain in enumerate(horizontal_chains):
                hits = []
                for layer_num, j in chain:
                    is_main = (int(layer_num), int(j)) in seed_set
                    rec = _chain_hit_record(layers, sector, imod, layer_num, j,
                                            'main' if is_main else 'support')
                    if rec is not None:
                        hits.append(rec)
                module_chains.append({
                    'chain_id': f'S{sector}_M{imod}_H{ic}',
                    'chain_type': 'horizontal',
                    'sector': sector,
                    'module': imod,
                    'side': side,
                    'hits': hits,
                })

            # Vertical chains (>= 3 main hits): main hits first, then support hits.
            for ic, (chain_main, chain_support) in enumerate(final_vertical_chains):
                if len(chain_main) < 3:
                    continue
                hits = []
                for hit_type, refs in (('main', chain_main), ('support', chain_support)):
                    for layer_num, j in refs:
                        rec = _chain_hit_record(layers, sector, imod, layer_num, j, hit_type)
                        if rec is not None:
                            hits.append(rec)
                module_chains.append({
                    'chain_id': f'S{sector}_M{imod}_V{ic}',
                    'chain_type': 'vertical',
                    'sector': sector,
                    'module': imod,
                    'side': side,
                    'hits': hits,
                })

# Print summary
print("\n" + "="*60)
print("CHAIN HITS STORAGE SUMMARY")
print("="*60)
total_chains = sum(len(chains) for chains in all_chain_hits.values())

print(f"Total chains stored: {total_chains}")
for side in [0, 1]:
    side_chains = sum(len(chains) for key, chains in all_chain_hits.items() if key[2] == side)
    print(f"  Side {side}: {side_chains} chains")

# Create visualizations - ONE canvas per side with ALL sectors and modules combined.
# TH1::Draw without "SAME" clears the pad. So the histogram (which defines the
# 3D frame) is drawn FIRST, and every chain graph is overlaid with "P SAME".
# All ROOT objects are kept in `vis_objects`: PyROOT deletes objects whose
# Python reference is dropped, which would remove them from the canvas.
vis_canvases = []
vis_objects = []   # keeps histograms / graphs alive while canvases are shown

# (hit_type, marker style, marker size): filled circles for main hits,
# open circles for support hits.
HIT_MARKERS = (("main", 20, 0.8), ("support", 24, 0.5))

for side in [0, 1]:
    c = root.TCanvas(
        f"c_phi_radius_side{side}_all",
        f"Phi-Radius-Timebin View - Side {side} (All Sectors & Modules)",
        1600, 1200
    )
    c.cd()

    # 3D histogram for this side (all data combined, ADC-weighted)
    h_phi_r_t = root.TH3F(
        f"h_phi_r_t_side{side}_all",
        f"Side {side} - All Sectors & Modules; #phi (rad); radius (cm); timebin",
        150, 0, 2*np.pi,  # Full phi range (0 to 2π)
        150, 28, 78,      # Full radius range (covers all 3 modules)
        150, 0, 500       # timebin range
    )
    vis_objects.append(h_phi_r_t)

    # Build one marker graph per (chain, hit type) and fill the histogram.
    side_graphs = []
    chain_counter = 0
    for sector in range(24):
        for imod in range(3):
            for chain in all_chain_hits.get((sector, imod, side), []):
                color = CHAIN_COLORS[chain_counter % len(CHAIN_COLORS)]
                for hit_type, mstyle, msize in HIT_MARKERS:
                    hits = [h for h in chain['hits'] if h['hit_type'] == hit_type]
                    if not hits:
                        continue
                    graph = root.TGraph2D(len(hits))
                    for i, hit in enumerate(hits):
                        graph.SetPoint(i, hit['phi'], hit['radius'], hit['timebin'])
                        h_phi_r_t.Fill(hit['phi'], hit['radius'], hit['timebin'], hit['adc'])
                    graph.SetMarkerColor(color)
                    graph.SetMarkerStyle(mstyle)
                    graph.SetMarkerSize(msize)
                    side_graphs.append(graph)
                chain_counter += 1

    # Histogram first (it clears the pad and sets up the frame), then overlays.
    h_phi_r_t.SetMarkerStyle(20)
    h_phi_r_t.SetMarkerSize(0.3)
    h_phi_r_t.Draw("SCAT")
    for graph in side_graphs:
        graph.Draw("P SAME")
    vis_objects.extend(side_graphs)

    # Optional overlay of all hits (disabled; needs cluster_points_cyl):
    # graph_all_hits = root.TGraph2D(len(cluster_points_cyl))
    # for i in range(len(cluster_points_cyl)):
    #     graph_all_hits.SetPoint(i,
    #                             cluster_points_cyl[i, 1],
    #                             cluster_points_cyl[i, 0],
    #                             cluster_points_cyl[i, 2])
    # graph_all_hits.SetMarkerColor(root.kRed)
    # graph_all_hits.Draw("P SAME")
    # vis_objects.append(graph_all_hits)

    c.Update()
    vis_canvases.append(c)

# Draw all canvases
for c in vis_canvases:
    c.Draw()

print("\n" + "="*60)
print("VISUALIZATION COMPLETE")
print(f"Created {len(vis_canvases)} canvases (one per side)")
print("Each canvas shows ALL sectors and modules combined")
print("Main hits: filled circles (larger)")
print("Support hits: open circles (smaller)")
print("Phi range: 0 to 2π (full azimuthal coverage)")
print("="*60)


# Plot parameter space 

In [None]:
# NOTE: This cell is disabled. Its entire body is wrapped in a ''' string
# literal, so none of the imports, the TrackFitClusterer class, or the
# synthetic-data demo below execute; Python only evaluates the string.
# When re-enabled it:
#   * defines TrackFitClusterer, which collects sagitta_fit_3d results,
#     standardizes the five fit parameters (pad_a, pad_b, pad_c, tbin_d,
#     tbin_e) with StandardScaler, optionally re-weights them, clusters them
#     with DBSCAN (label -1 = noise) and prints/plots the clusters;
#   * runs a demo on 100 synthetic fits drawn from 3 hand-made track types
#     (np.random.seed(42)), needing sklearn, scipy and matplotlib.
# NOTE(review): as the cell's last expression, this string is likely echoed
#   as cell output by Jupyter; end it with ';' or delete the cell if unwanted.
# NOTE(review): calling add_fit() after cluster() leaves the new entry without
#   a 'cluster' key and makes self.labels shorter than self.fit_data, so
#   get_cluster_stats() (KeyError) and _plot_single_view() (mask length
#   mismatch) would fail; re-run cluster() after adding fits.
# NOTE(review): the Ellipse import and the demo's `noise = 0.3` are unused.
'''import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from scipy.spatial.distance import pdist, squareform
#import seaborn as sns

class TrackFitClusterer:
    """
    Cluster track fits based on their parameters in 5D space.
    Finds groups of tracks with similar fit parameters.
    """
    
    def __init__(self):
        self.fit_data = []
        self.scaler = StandardScaler()
        self.labels = None
        self.normalized_params = None
        
    def add_fit(self, fit_params, metadata=None):
        """
        Add a fit result from sagitta_fit_3d to the collection.
        
        Args:
            fit_params: dict returned from sagitta_fit_3d
            metadata: optional dict with additional info (e.g., event_id, track_id)
        """
        if fit_params is None:
            return
            
        fit_entry = {
            'pad_a': fit_params['pad_params'][0],
            'pad_b': fit_params['pad_params'][1],
            'pad_c': fit_params['pad_params'][2],
            'tbin_d': fit_params['tbin_params'][0],
            'tbin_e': fit_params['tbin_params'][1],
            'chi2_pad': fit_params['chi2_pad'],
            'chi2_tbin': fit_params['chi2_tbin'],
            'sagitta': fit_params['sagitta'],
            'n_hits': fit_params['n_hits'],
            'total_adc': fit_params['total_adc'],
        }
        
        if metadata:
            fit_entry.update(metadata)
            
        self.fit_data.append(fit_entry)
        
    def cluster(self, eps=0.5, min_samples=2, weights=None):
        """
        Perform clustering on collected fits.
        
        Args:
            eps: DBSCAN epsilon parameter (distance threshold in normalized space)
            min_samples: minimum number of points to form a cluster
            weights: dict to weight parameters differently, e.g., {'pad_a': 2.0, 'tbin_d': 1.0}
        
        Returns:
            labels: cluster assignments (-1 for noise)
        """
        if len(self.fit_data) < 2:
            print("Need at least 2 fits for clustering")
            return None
            
        # Extract parameters for clustering
        param_names = ['pad_a', 'pad_b', 'pad_c', 'tbin_d', 'tbin_e']
        X = np.array([[fit[p] for p in param_names] for fit in self.fit_data])
        
        # Normalize parameters
        self.normalized_params = self.scaler.fit_transform(X)
        
        # Apply weights if provided
        if weights:
            for i, param in enumerate(param_names):
                if param in weights:
                    self.normalized_params[:, i] *= weights[param]
        
        # Cluster using DBSCAN
        clusterer = DBSCAN(eps=eps, min_samples=min_samples)
        self.labels = clusterer.fit_predict(self.normalized_params)
        
        # Add cluster labels to fit_data
        for i, fit in enumerate(self.fit_data):
            fit['cluster'] = self.labels[i]
        
        return self.labels
    
    def get_cluster_stats(self):
        """Get statistics for each cluster"""
        if self.labels is None:
            print("Run cluster() first")
            return None
            
        unique_labels = set(self.labels)
        unique_labels.discard(-1)  # Remove noise label
        
        stats = {}
        for label in unique_labels:
            cluster_fits = [f for f in self.fit_data if f['cluster'] == label]
            
            stats[label] = {
                'n_fits': len(cluster_fits),
                'avg_sagitta': np.mean([f['sagitta'] for f in cluster_fits]),
                'std_sagitta': np.std([f['sagitta'] for f in cluster_fits]),
                'avg_chi2_pad': np.mean([f['chi2_pad'] for f in cluster_fits]),
                'avg_chi2_tbin': np.mean([f['chi2_tbin'] for f in cluster_fits]),
                'avg_n_hits': np.mean([f['n_hits'] for f in cluster_fits]),
                'avg_total_adc': np.mean([f['total_adc'] for f in cluster_fits]),
                'avg_params': {
                    'pad_a': np.mean([f['pad_a'] for f in cluster_fits]),
                    'pad_b': np.mean([f['pad_b'] for f in cluster_fits]),
                    'pad_c': np.mean([f['pad_c'] for f in cluster_fits]),
                    'tbin_d': np.mean([f['tbin_d'] for f in cluster_fits]),
                    'tbin_e': np.mean([f['tbin_e'] for f in cluster_fits]),
                }
            }
        
        # Add noise stats
        noise_fits = [f for f in self.fit_data if f['cluster'] == -1]
        if noise_fits:
            stats[-1] = {
                'n_fits': len(noise_fits),
                'avg_sagitta': np.mean([f['sagitta'] for f in noise_fits]),
            }
        
        return stats
    
    def plot_clusters(self, view='pad_a_vs_b', figsize=(15, 10)):
        """
        Visualize clusters in parameter space.
        
        Args:
            view: which parameters to plot
                'pad_a_vs_b', 'tbin_d_vs_e', 'sagitta_vs_chi2', 
                'pad_params', 'all'
        """
        if self.labels is None:
            print("Run cluster() first")
            return
        
        views = {
            'pad_a_vs_b': ('pad_a', 'pad_b', 'Pad Curvature (a)', 'Pad Slope (b)'),
            'tbin_d_vs_e': ('tbin_d', 'tbin_e', 'Tbin Slope (d)', 'Tbin Intercept (e)'),
            'sagitta_vs_chi2': ('sagitta', 'chi2_pad', 'Sagitta', 'Chi² (pad)'),
            'pad_params': ('pad_b', 'pad_c', 'Pad Slope (b)', 'Pad Intercept (c)'),
        }
        
        if view == 'all':
            fig, axes = plt.subplots(2, 2, figsize=figsize)
            axes = axes.flatten()
            
            for idx, (view_name, (x_key, y_key, x_label, y_label)) in enumerate(views.items()):
                self._plot_single_view(axes[idx], x_key, y_key, x_label, y_label)
            
            plt.tight_layout()
        else:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
            x_key, y_key, x_label, y_label = views[view]
            self._plot_single_view(ax, x_key, y_key, x_label, y_label)
        
        plt.show()
    
    def _plot_single_view(self, ax, x_key, y_key, x_label, y_label):
        """Helper to plot a single view"""
        x = [f[x_key] for f in self.fit_data]
        y = [f[y_key] for f in self.fit_data]
        
        unique_labels = set(self.labels)
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))
        
        for label, color in zip(unique_labels, colors):
            if label == -1:
                # Noise points
                mask = self.labels == label
                ax.scatter(np.array(x)[mask], np.array(y)[mask], 
                          c='gray', marker='x', s=50, alpha=0.3, label='Noise')
            else:
                mask = self.labels == label
                ax.scatter(np.array(x)[mask], np.array(y)[mask],
                          c=[color], marker='o', s=100, alpha=0.7, 
                          label=f'Cluster {label}', edgecolors='black', linewidth=0.5)
        
        ax.set_xlabel(x_label, fontsize=11)
        ax.set_ylabel(y_label, fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    
    def plot_distance_matrix(self):
        """Plot pairwise distance matrix in normalized parameter space"""
        if self.normalized_params is None:
            print("Run cluster() first")
            return
        
        # Compute pairwise distances
        distances = squareform(pdist(self.normalized_params, metric='euclidean'))
        
        # Sort by cluster label for better visualization
        sorted_indices = np.argsort(self.labels)
        distances_sorted = distances[sorted_indices][:, sorted_indices]
        labels_sorted = self.labels[sorted_indices]
        
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(distances_sorted, cmap='viridis', aspect='auto')
        
        # Add cluster boundaries
        cluster_changes = np.where(np.diff(labels_sorted) != 0)[0] + 0.5
        for change in cluster_changes:
            ax.axhline(change, color='red', linewidth=2)
            ax.axvline(change, color='red', linewidth=2)
        
        plt.colorbar(im, ax=ax, label='Normalized Distance')
        ax.set_xlabel('Fit Index (sorted by cluster)', fontsize=11)
        ax.set_ylabel('Fit Index (sorted by cluster)', fontsize=11)
        ax.set_title('Pairwise Distance Matrix', fontsize=13, fontweight='bold')
        plt.tight_layout()
        plt.show()
    
    def print_summary(self):
        """Print summary of clustering results"""
        if self.labels is None:
            print("Run cluster() first")
            return
        
        stats = self.get_cluster_stats()
        
        print("="*70)
        print("TRACK FIT CLUSTERING SUMMARY")
        print("="*70)
        print(f"Total fits: {len(self.fit_data)}")
        print(f"Number of clusters: {len([l for l in set(self.labels) if l != -1])}")
        print(f"Noise points: {np.sum(self.labels == -1)}")
        print("="*70)
        
        for label in sorted(stats.keys()):
            if label == -1:
                print(f"\nNoise Points: {stats[label]['n_fits']} fits")
                print(f"  Avg Sagitta: {stats[label]['avg_sagitta']:.4f}")
            else:
                print(f"\nCluster {label}: {stats[label]['n_fits']} fits")
                print(f"  Sagitta: {stats[label]['avg_sagitta']:.4f} ± {stats[label]['std_sagitta']:.4f}")
                print(f"  Chi² (pad): {stats[label]['avg_chi2_pad']:.4f}")
                print(f"  Chi² (tbin): {stats[label]['avg_chi2_tbin']:.4f}")
                print(f"  Avg hits: {stats[label]['avg_n_hits']:.1f}")
                print(f"  Avg ADC: {stats[label]['avg_total_adc']:.1f}")
                print(f"  Avg Params:")
                print(f"    Pad: a={stats[label]['avg_params']['pad_a']:.5f}, "
                      f"b={stats[label]['avg_params']['pad_b']:.3f}, "
                      f"c={stats[label]['avg_params']['pad_c']:.2f}")
                print(f"    Tbin: d={stats[label]['avg_params']['tbin_d']:.3f}, "
                      f"e={stats[label]['avg_params']['tbin_e']:.2f}")
        print("="*70)


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

# Initialize clusterer
clusterer = TrackFitClusterer()

# Example: Process your fits from the sagitta_fit_3d function
# Assuming you have chains of hits and you run sagitta_fit_3d on each chain:

"""
for chain_id, chain in enumerate(your_chains):
    fit_params, fit_line = sagitta_fit_3d(chain, key, layers, pts_arr, payload)
    
    if fit_params is not None:
        # Add fit to clusterer with metadata
        clusterer.add_fit(fit_params, metadata={
            'chain_id': chain_id,
            'event_id': event_id  # if you have event information
        })
"""

# For demonstration, let's create some synthetic fits
print("Generating synthetic fit data for demonstration...")
np.random.seed(42)

# Simulate 3 types of tracks with different characteristics
track_types = [
    {'pad_a': 0.01, 'pad_b': 0.5, 'pad_c': 10, 'tbin_d': 2.0, 'tbin_e': 5},
    {'pad_a': -0.02, 'pad_b': -0.3, 'pad_c': 15, 'tbin_d': 1.5, 'tbin_e': 8},
    {'pad_a': 0.005, 'pad_b': 0.8, 'pad_c': 12, 'tbin_d': 2.5, 'tbin_e': 3},
]

for i in range(100):
    track_type = track_types[i % 3]
    noise = 0.3
    
    synthetic_fit = {
        'pad_params': [
            track_type['pad_a'] + np.random.randn() * 0.003,
            track_type['pad_b'] + np.random.randn() * 0.15,
            track_type['pad_c'] + np.random.randn() * 1.0,
        ],
        'tbin_params': [
            track_type['tbin_d'] + np.random.randn() * 0.15,
            track_type['tbin_e'] + np.random.randn() * 0.8,
        ],
        'chi2_pad': np.abs(np.random.randn() * 0.5 + 0.5),
        'chi2_tbin': np.abs(np.random.randn() * 0.5 + 0.5),
        'sagitta': np.abs(track_type['pad_a']) * 10 + np.abs(np.random.randn() * 0.3),
        'n_hits': int(np.random.randint(5, 15)),
        'total_adc': np.random.rand() * 500 + 500,
    }
    
    clusterer.add_fit(synthetic_fit, metadata={'track_id': i})

print(f"Added {len(clusterer.fit_data)} fits")

# Perform clustering
print("\nRunning clustering algorithm...")
labels = clusterer.cluster(eps=0.5, min_samples=3)

# Print summary
clusterer.print_summary()

# Visualize results
print("\nGenerating visualizations...")
clusterer.plot_clusters(view='all')
clusterer.plot_distance_matrix()

# Get cluster statistics for further analysis
stats = clusterer.get_cluster_stats()

# Example: Find all fits in a specific cluster
cluster_0_fits = [f for f in clusterer.fit_data if f['cluster'] == 0]
print(f"\nCluster 0 contains {len(cluster_0_fits)} fits")
print(f"Track IDs in Cluster 0: {[f['track_id'] for f in cluster_0_fits[:10]]}...")
'''

In [None]:
# NOTE: This cell is disabled. The definition below is wrapped in a ''' string
# literal, so print_final_chains_layer_by_layer is NOT defined when it runs.
# When re-enabled, the function does the following for each
# (chain_main, chain_support) pair in `final_chains`:
#   * groups its (layer, j) hits by layer;
#   * optionally keeps only layers in [layer_min, layer_max];
#   * sorts layers ascending, or descending when layer_step_hint == -1;
#   * prints every hit as (gi, t, p, adc), resolved through
#     layers[ly]["idx"], layers[ly]["tp"][:, 0/1] and layers[ly]["adc"].
# NOTE(review): t/p are read from columns 0/1 of layers[ly]["tp"]; presumably
#   time bin and pad -- confirm against the layer builder.
# NOTE(review): max_chains caps the chain indices examined, not the chains
#   printed; chains skipped by the layer filter still use up that budget.
'''def print_final_chains_layer_by_layer(
    final_chains, layers,
    max_chains=None,
    sort_layers=True,
    layer_step_hint=None,
    layer_min=None,
    layer_max=None,
):
    def _hit_info(layer, j):
        ld = layers[int(layer)]
        gi = int(ld["idx"][int(j)])
        t  = int(ld["tp"][int(j), 0])
        p  = int(ld["tp"][int(j), 1])
        w  = float(ld["adc"][int(j)])
        return gi, t, p, w

    n = len(final_chains)
    nprint = n if max_chains is None else min(n, int(max_chains))

    print("\n" + "="*90)
    print(f"FINAL CHAINS: total={n}, printing={nprint}  (layer_min={layer_min}, layer_max={layer_max})")
    print("="*90)

    for ic in range(nprint):
        chain_main, chain_support = final_chains[ic]

        by_layer = {}
        for ly, j in chain_main:
            ly = int(ly); j = int(j)
            if layer_min is not None and ly < layer_min: continue
            if layer_max is not None and ly > layer_max: continue
            by_layer.setdefault(ly, {"main": [], "support": []})["main"].append(j)

        for ly, j in chain_support:
            ly = int(ly); j = int(j)
            if layer_min is not None and ly < layer_min: continue
            if layer_max is not None and ly > layer_max: continue
            by_layer.setdefault(ly, {"main": [], "support": []})["support"].append(j)

        if not by_layer:
            continue  # nothing in requested layer range

        layers_list = list(by_layer.keys())
        if sort_layers:
            if layer_step_hint in (+1, -1):
                layers_list.sort(reverse=(layer_step_hint == -1))
            else:
                layers_list.sort()

        n_main = sum(len(by_layer[ly]["main"]) for ly in layers_list)
        n_sup  = sum(len(by_layer[ly]["support"]) for ly in layers_list)

        print("\n" + "-"*90)
        print(f"Chain {ic}: n_main={n_main}  n_support={n_sup}  n_layers={len(layers_list)} (filtered)")
        print("-"*90)

        for ly in layers_list:
            mains = by_layer[ly]["main"]
            sups  = by_layer[ly]["support"]
            print(f"  Layer {ly:3d}: main={len(mains):2d}, support={len(sups):2d}")

            for j in sorted(mains):
                gi, t, p, w = _hit_info(ly, j)
                print(f"    M  j={j:4d}  gi={gi:7d}  t={t:5d}  p={p:4d}  adc={w:8.2f}")

            for j in sorted(sups):
                gi, t, p, w = _hit_info(ly, j)
                print(f"    S  j={j:4d}  gi={gi:7d}  t={t:5d}  p={p:4d}  adc={w:8.2f}")
'''

In [None]:
# NOTE: This cell is disabled because it is wrapped in a ''' string literal.
# When re-enabled, it prints final_vertical_chains layer by layer, keeping
# only layers <= 22 and sorting them in ascending order.
# It needs print_final_chains_layer_by_layer, final_vertical_chains and
# layers to already exist in the notebook namespace.
'''print_final_chains_layer_by_layer(
    final_vertical_chains, layers,
    layer_max=22,          # <-- only layers <= 22
    layer_step_hint=+1
)'''


In [None]:
# NOTE: This cell is disabled. Everything below is wrapped in a ''' string
# literal, so no ROOT file is written when the cell runs.
# When re-enabled, it:
#   * loops over sectors 0-11, sides 0/1 and the module indices found in the
#     (sec, imod, side) keys of `hists`;
#   * rebuilds the per-key layer structures and seeds;
#   * finds horizontal chains (chain_type=0, stored as main hits only) and
#     two-pass vertical chains (chain_type=1, main + support hits);
#   * writes two TTrees to `out_root`: "ChainHits" (one entry per hit) and
#     "Chains" (one entry per chain).
# It relies on hists, pts_arr, payload and the helpers
# build_layer_indices_and_trees, find_seeds_sorted,
# find_horizontal_chains_allhits and two_pass_chaining from other cells.
# NOTE(review): the [0,1,2] fallback for mods_in_hists is dead code. When
#   `hists` is not in globals(), `if key not in hists` raises NameError anyway.
# NOTE(review): used_global is passed only to the horizontal pass. If that
#   pass records its hits there (as the name suggests), the vertical pass never
#   sees them, so a hit could be stored in both chain types for the same key.
# NOTE(review): out_root is an absolute, user-specific path; consider making
#   it configurable.
'''# ============================================================
# SAVE ALL CHAINS (horizontal + vertical) FOR ALL
# sectors 0-11, sides 0-1, modules (from available keys)
# into ONE ROOT file that you can read from another notebook.
# ============================================================

import ROOT as root
from array import array

out_root = "/Users/mitrankova/Jupyter/PatternRecognition/all_chains.root"

# --- open output ROOT ---
fout = root.TFile.Open(out_root, "RECREATE")
if not fout or fout.IsZombie():
    raise RuntimeError(f"Cannot create {out_root}")

# -------------------------
# Flat per-hit tree (robust)
# -------------------------
th = root.TTree("ChainHits", "One entry per hit with chain metadata")

# scalars (arrays)
sec_a       = array('i', [0])
side_a      = array('i', [0])
mod_a       = array('i', [0])
chain_uid_a = array('i', [0])   # unique across whole file
chain_id_a  = array('i', [0])   # local id within (sec,side,mod,chain_type)
chain_type_a= array('i', [0])   # 0=horizontal, 1=vertical
is_main_a   = array('i', [0])   # 1=main, 0=support

layer_a     = array('i', [0])
j_a         = array('i', [0])
gi_a        = array('i', [0])
t_a         = array('i', [0])
p_a         = array('i', [0])
adc_a       = array('f', [0.0])

th.Branch("sec",       sec_a,        "sec/I")
th.Branch("side",      side_a,       "side/I")
th.Branch("mod",       mod_a,        "mod/I")
th.Branch("chain_uid", chain_uid_a,  "chain_uid/I")
th.Branch("chain_id",  chain_id_a,   "chain_id/I")
th.Branch("chain_type",chain_type_a, "chain_type/I")
th.Branch("is_main",   is_main_a,    "is_main/I")

th.Branch("layer",     layer_a,      "layer/I")
th.Branch("j",         j_a,          "j/I")
th.Branch("gi",        gi_a,         "gi/I")
th.Branch("t",         t_a,          "t/I")
th.Branch("p",         p_a,          "p/I")
th.Branch("adc",       adc_a,        "adc/F")

# -------------------------
# Optional per-chain summary
# -------------------------
tc = root.TTree("Chains", "One entry per chain summary")
sec2_a       = array('i', [0])
side2_a      = array('i', [0])
mod2_a       = array('i', [0])
chain_uid2_a = array('i', [0])
chain_id2_a  = array('i', [0])
chain_type2_a= array('i', [0])
n_main_a     = array('i', [0])
n_sup_a      = array('i', [0])
n_layers_a   = array('i', [0])

tc.Branch("sec",       sec2_a,       "sec/I")
tc.Branch("side",      side2_a,      "side/I")
tc.Branch("mod",       mod2_a,       "mod/I")
tc.Branch("chain_uid", chain_uid2_a, "chain_uid/I")
tc.Branch("chain_id",  chain_id2_a,  "chain_id/I")
tc.Branch("chain_type",chain_type2_a,"chain_type/I")
tc.Branch("n_main",    n_main_a,     "n_main/I")
tc.Branch("n_support", n_sup_a,      "n_support/I")
tc.Branch("n_layers",  n_layers_a,   "n_layers/I")

# -------------------------
# Helpers
# -------------------------
def _hit_fields(layers, ly, j):
    ld = layers[int(ly)]
    gi = int(ld["idx"][int(j)])
    tt = int(ld["tp"][int(j), 0])
    pp = int(ld["tp"][int(j), 1])
    ww = float(ld["adc"][int(j)])
    return gi, tt, pp, ww

def _fill_chain_hits(sec, side, mod, chain_uid, chain_id_local, chain_type,
                     chain_main, chain_support, layers):
    # per-chain summary
    sec2_a[0]        = int(sec)
    side2_a[0]       = int(side)
    mod2_a[0]        = int(mod)
    chain_uid2_a[0]  = int(chain_uid)
    chain_id2_a[0]   = int(chain_id_local)
    chain_type2_a[0] = int(chain_type)
    n_main_a[0]      = int(len(chain_main))
    n_sup_a[0]       = int(len(chain_support))

    # unique layers count across main+support
    layset = set(int(ly) for (ly, _) in chain_main)
    layset.update(int(ly) for (ly, _) in chain_support)
    n_layers_a[0] = int(len(layset))
    tc.Fill()

    # per-hit rows: MAIN
    for ly, j in chain_main:
        gi, tt, pp, ww = _hit_fields(layers, ly, j)
        sec_a[0]        = int(sec)
        side_a[0]       = int(side)
        mod_a[0]        = int(mod)
        chain_uid_a[0]  = int(chain_uid)
        chain_id_a[0]   = int(chain_id_local)
        chain_type_a[0] = int(chain_type)
        is_main_a[0]    = 1

        layer_a[0] = int(ly)
        j_a[0]     = int(j)
        gi_a[0]    = int(gi)
        t_a[0]     = int(tt)
        p_a[0]     = int(pp)
        adc_a[0]   = float(ww)
        th.Fill()

    # per-hit rows: SUPPORT
    for ly, j in chain_support:
        gi, tt, pp, ww = _hit_fields(layers, ly, j)
        sec_a[0]        = int(sec)
        side_a[0]       = int(side)
        mod_a[0]        = int(mod)
        chain_uid_a[0]  = int(chain_uid)
        chain_id_a[0]   = int(chain_id_local)
        chain_type_a[0] = int(chain_type)
        is_main_a[0]    = 0

        layer_a[0] = int(ly)
        j_a[0]     = int(j)
        gi_a[0]    = int(gi)
        t_a[0]     = int(tt)
        p_a[0]     = int(pp)
        adc_a[0]   = float(ww)
        th.Fill()

# -------------------------
# Main loop over all keys
# -------------------------
chain_uid = 0
n_keys = 0
n_chains_total = 0

# determine module range from hists keys if available
# expected key: (sec, imod, side)
mods_in_hists = sorted({int(k[1]) for k in hists.keys()}) if "hists" in globals() else [0,1,2]

for sec in range(12):
    for side in (0, 1):
        for imod in mods_in_hists:
            key = (sec, imod, side)
            if key not in hists:
                continue

            n_keys += 1

            # Build per-key layer structures
            layers = build_layer_indices_and_trees(key, pts_arr, payload)

            # Seeds
            seeds_sorted = find_seeds_sorted(layers)  # (layer, j, adc)
            seed_set = set((int(ly), int(j)) for (ly, j, adc) in seeds_sorted)

            # used within THIS (sec,side,mod) only
            used_global = set()

            # -------------------------
            # HORIZONTAL chains
            # -------------------------
            horizontal_chains = find_horizontal_chains_allhits(
                layers, seeds_sorted, used_global,
                dt=2, dp=3,
                min_hits=3,
                min_pad_span=5,
                order_for_drawing=True
            )

            # store horizontal: treat as "main only", no support
            chain_id_local = 0
            for ch in horizontal_chains:
                if len(ch) < 2:
                    continue
                _fill_chain_hits(
                    sec, side, imod,
                    chain_uid=chain_uid,
                    chain_id_local=chain_id_local,
                    chain_type=0,                # 0=horizontal
                    chain_main=ch,
                    chain_support=[],
                    layers=layers
                )
                chain_uid += 1
                chain_id_local += 1
                n_chains_total += 1

            # -------------------------
            # VERTICAL chains (two pass)
            # -------------------------
            final_vertical_chains = two_pass_chaining(
                layers, seeds_sorted, seed_set,
                dt_win0=3, dp_win0=2,
                dt_gate=1, dp_gate=1,
                allow_skip_one_layer=True
                # If you upgraded your two_pass_chaining to accept used_global, pass it here:
                # , used_global=used_global
            )

            chain_id_local = 0
            for (chain_main, chain_support) in final_vertical_chains:
                if len(chain_main) < 2:
                    continue
                _fill_chain_hits(
                    sec, side, imod,
                    chain_uid=chain_uid,
                    chain_id_local=chain_id_local,
                    chain_type=1,                # 1=vertical
                    chain_main=chain_main,
                    chain_support=chain_support,
                    layers=layers
                )
                chain_uid += 1
                chain_id_local += 1
                n_chains_total += 1

            if (n_keys % 10) == 0:
                print(f"[progress] keys processed={n_keys}, chains stored={n_chains_total}")

# write and close
fout.cd()
tc.Write()
th.Write()
fout.Close()

print("============================================================")
print(f"Saved to: {out_root}")
print(f"Keys processed: {n_keys}")
print(f"Chains stored:  {n_chains_total}")
#print(f"Total hits rows in ChainHits: {th.GetEntries()}")
print("============================================================")
'''