In [3]:
from gnn_tracking.utils.loading import TrackingDataModule

dm = TrackingDataModule(
    train=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_1/"
        ],
    ),
    val=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/"
        ],
        stop=5
    ),
)
dm.setup(stage="fit")

[32m[20:15:26] INFO: DataLoader will load 28800 graphs (out of 28800 available).[0m
[36m[20:15:26] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_1/data21999_s9.pt[0m
[32m[20:15:27] INFO: DataLoader will load 5 graphs (out of 32000 available).[0m
[36m[20:15:27] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29000_s12.pt[0m


In [1]:
import networkx as nx
from torch import Tensor as T

def get_cc_labels(edge_index: T, num_nodes: int) -> T:

    gx = nx.Graph()
    gx.add_nodes_from(list(range(num_nodes)))
    gx.add_edges_from(edge_index.T.detach().cpu().numpy())
    components = nx.connected_components(gx)
    index_mapping = {node: index for index,node_set in enumerate(components) for node in node_set}
    return T([index_mapping[node] for node in gx.nodes()])


In [151]:
from gnn_tracking.metrics.cluster_metrics import TrackingMetrics, tracking_metrics, \
    tracking_metric_df, ClusterMetricType, _tracking_metrics_nan_results, \
    count_tracking_metrics
from typing import Iterable
from torch_geometric.data import Data
import numpy as np
import pandas as pd



def tracking_metrics(
    *,
    truth: np.ndarray,
    predicted: np.ndarray,
    pts: np.ndarray,
    reconstructable: np.ndarray,
    eta: np.ndarray,
    pt_thlds: Iterable[float],
    predicted_count_thld=3,
    max_eta=4,
) -> dict[float, TrackingMetrics]:
    """Calculate 'custom' metrics for matching tracks and hits.

    Args:
        truth: Truth labels/PIDs for each hit
        predicted: Predicted labels/cluster index for each hit. Negative labels are
            interpreted as noise (because this is how DBSCAN outputs it) and are
            ignored
        pts: true pt value of particle belonging to each hit
        reconstructable: Whether the hit belongs to a "reconstructable tracks" (this
            usually implies a cut on the number of layers that are being hit
            etc.)
        eta: true pseudorapidity of particle belong to each hit
        pt_thlds: pt thresholds to calculate the metrics for
        predicted_count_thld: Minimal number of hits in a cluster for it to not be
            rejected.
        max_eta: Maximum eta value to count

    Returns:
        See `TrackingMetrics`
    """
    for ar in (truth, predicted, pts, reconstructable, eta):
        # Tensors behave differently when counting, so this is absolutely vital!
        assert isinstance(ar, np.ndarray)
    assert predicted.shape == truth.shape == pts.shape, (
        predicted.shape,
        truth.shape,
        pts.shape,
        eta.shape,
    )
    if len(truth) == 0:
        return {pt: _tracking_metrics_nan_results for pt in pt_thlds}
    pids, counts = np.unique(truth, return_counts=True)
    pid2count = dict(zip(pids, counts))
    count_ar = np.array([pid2count[pid] for pid in truth])
    h_df = pd.DataFrame(
        {
            "c": predicted,
            "id": truth,
            "pt": pts,
            "reconstructable": reconstructable,
            "eta": eta,
            "n_hits": count_ar
        }
    )
    c_df = tracking_metric_df(h_df, predicted_count_thld=predicted_count_thld)

    result = dict[float, ClusterMetricType]()
    for pt in pt_thlds:
        c_mask = (
            (c_df["maj_pt"] >= pt)
            & c_df["maj_reconstructable"]
            & (c_df["maj_eta"].abs() < max_eta)
            & c_df["valid_cluster"]
        )
        h_mask = (
            (h_df["pt"] >= pt)
            & h_df["reconstructable"].astype(bool)
            & (h_df["eta"].abs() < max_eta)
            & (h_df["n_hits"] >= predicted_count_thld)
        )

        r = count_tracking_metrics(c_df, h_df, c_mask, h_mask)
        result[pt] = r  # type: ignore
    return c_df, result  # type: ignore


def tracking_metrics_data(data: Data, labels, pt_thlds: Iterable[float],
    predicted_count_thld=3,
    max_eta=4,):
    return tracking_metrics(
        truth=data.particle_id.detach().cpu().numpy(),
        predicted=labels,
        pts=data.pt.detach().cpu().numpy(),
        reconstructable=data.reconstructable.detach().cpu().numpy(),
        eta=data.eta.detach().cpu().numpy(),
        pt_thlds=pt_thlds,
        max_eta=max_eta,
        predicted_count_thld=predicted_count_thld,
    )

In [152]:
def get_best_truth_metrics(data):
    ei = data.edge_index
    y = data.y.bool()
    tl = get_cc_labels(ei[:, y], num_nodes=data.num_nodes).long()
    return tracking_metrics_data(data, tl.detach().cpu().numpy(), [0.9])

In [34]:
from gnn_tracking.analysis.graphs import get_largest_segment_fracs

In [35]:
data = dm.datasets["train"][0]
data.y = data.y.bool()

In [150]:
from collections import defaultdict
from typing import Set
from gnn_tracking.utils.graph_masks import get_good_node_mask
import torch


def get_largest_segment_fracs(
    data: Data,
    *,
    pt_thld=0.9,
    n_particles_sampled=None,
    max_eta=4,
    count_thld=0,
) :
    """A fast way to get the fraction of hits in the largest segment for each particle.

    Args:
        data:
        pt_thld:
        n_particles_sampled: If not None, only consider a subsample of the particles.
            This speeds up calculation but introduces statistical fluctuations.
        max_eta: Maximum pseudorapidity
        count_thld: Minimum size of segment to be considered.

    Returns:
        Array of fractions.
    """
    # This implementation simply looks at the connected components for a graph
    # with all true edges stripped (so connected component = segment).
    basic_hit_mask = get_good_node_mask(data, pt_thld=pt_thld, max_eta=max_eta)
    unique_pids, counts = torch.unique(
        data.particle_id[basic_hit_mask], return_counts=True
    )
    pid2count = dict(zip(unique_pids.tolist(), counts.tolist()))
    if n_particles_sampled is not None:
        rand_perm = torch.randperm(len(unique_pids))
        unique_pids = unique_pids[rand_perm][:n_particles_sampled]
        basic_hit_mask &= torch.isin(data.particle_id, unique_pids)
    rdata = Data(
        edge_index=data.edge_index[:, data.y],
        particle_id=data.particle_id,
        num_nodes=len(data.particle_id),
    ).subgraph(basic_hit_mask)
    gx = nx.Graph()
    gx.add_edges_from(rdata.edge_index.T.detach().cpu().numpy())
    segments: list[Set[int]] = nx.connected_components(gx)
    pid_to_largest_segment = defaultdict(int)
    for segment in segments:
        if len(segment) < count_thld:
            continue
        # PID is the same for all nodes in connected component by construction
        pid = rdata.particle_id[next(iter(segment))].item()
        assert (rdata.particle_id[list(segment)] == pid).all()
        pid_to_largest_segment[pid] = max(
            pid_to_largest_segment[pid], len(segment) / pid2count[pid]
        )
    return pid_to_largest_segment, np.array(list(pid_to_largest_segment.values()))

In [103]:
cdf[cdf["maj_pid"] == 247700453406539776]

Unnamed: 0_level_0,maj_pid,maj_hits,cluster_size,valid_cluster,maj_reconstructable,maj_eta,maj_pt,maj_pid_hits,maj_frac,maj_pid_frac,perfect_match,double_majority,lhc_match
c,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
364,247700453406539776,3,3,True,1.0,2.433346,0.915068,7,1.0,0.428571,False,False,True
739,247700453406539776,4,4,True,1.0,2.433346,0.915068,7,1.0,0.571429,False,True,True


In [154]:
cdf, mtrx = get_best_truth_metrics(data)
mtrx

{0.9: {'n_particles': 68,
  'n_cleaned_clusters': 70,
  'perfect': 0.8676470588235294,
  'double_majority': 0.9411764705882353,
  'lhc': 1.0,
  'fake_perfect': 0.16176470588235295,
  'fake_double_majority': 0.08823529411764706,
  'fake_lhc': 0.0}}

In [155]:
lfdct, lsf = get_largest_segment_fracs(data, count_thld=3)

In [156]:
len(lfdct)

65

In [158]:
(lsf > 0.5).sum() / 68

0.9411764705882353

In [146]:
for pid in lfdct:
    if lfdct[pid] <= 0.5:
        continue
    sel = cdf[cdf["maj_pid"] == pid]
    if len(sel) == 0:
        print(1, pid)
    if not sel["double_majority"].any():
        print(2, pid)

In [None]:
# fixme: Reconstructable doesn't have the count >= 3 cut in it?

## Check sector PID overlap

In [1]:
from pathlib import Path
base_path = Path("/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/")

In [14]:
example_data = list(base_path.glob("data29004_*"))

In [15]:
example_data

[PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s28.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s30.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s2.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s21.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s15.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s18.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s26.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s4.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v5/part_9/data29004_s0.pt'),
 PosixPath('/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/

In [16]:
import torch

In [17]:
seen = set()
overlapped = set()
for p in example_data:
    data = torch.load(p)
    new = set(torch.unique(data.particle_id).tolist())
    overlapped |= seen & new
    seen |= new

In [18]:
len(overlapped) / len(seen)

0.6560659599528857