# Converting data into PyTorch Geometric format

* **Status**: Active

**Code adapted from [Jonathan's gist](https://gist.github.com/jkguiang/ea0c7438e76efa61a29a8576b7781cce).**

In [22]:
from tqdm.auto import tqdm

import uproot
import torch
from torch_geometric.data import Data
from pathlib import Path
from torch import Tensor as T

In [23]:
import dataclasses


@dataclasses.dataclass(kw_only=True)
class IngressColumnConfig:
    """Names of the ROOT branches used to build one graph per event.

    All fields are keyword-only and hold branch names.
    """

    # Per-edge truth-flag branch. NOTE(review): LSDatasetConverter.convert_batch
    # does not read this field as written; edge truth is derived from
    # ``particle_id`` instead.
    edge_truth_label: str
    # Branches holding, for each edge, the index of its first / second node.
    edge_indices: tuple[str, str]
    # Per-edge feature branches; stacked as feature columns in this order.
    edge_features: list[str]
    # Per-node feature branches; stacked as feature columns in this order.
    node_features: list[str]
    # Per-node particle-id branch; an edge is labelled true when both of its
    # nodes share the same non-negative id.
    particle_id: str
    # Truth branches stored on the graph as ``pt`` and ``eta`` respectively.
    truth_pt: str
    truth_eta: str

In [24]:
# Branch names of the LST ntuple consumed by LSDatasetConverter.
icc = IngressColumnConfig(
    # --- Node (MD) branches -------------------------------------------------
    # MD_0_* / MD_1_* are presumably the positions of the two hits making up
    # each MD -- TODO confirm against the ntuple definition.
    node_features=[
        "MD_0_x",
        "MD_0_y",
        "MD_0_z",
        "MD_1_x",
        "MD_1_y",
        "MD_1_z",
        "MD_dphichange",
        "MD_phi",
        "MD_eta",
    ],
    particle_id="MD_sim_idx",
    truth_pt="MD_sim_pt",
    truth_eta="MD_sim_eta",
    # --- Edge (LS) branches -------------------------------------------------
    # Each LS joins the two MDs whose indices are stored in these branches.
    edge_indices=("LS_MD_idx0", "LS_MD_idx1"),
    edge_features=["LS_pt", "LS_eta", "LS_phi"],
    edge_truth_label="LS_isFake",
)

In [25]:
def r_phi_eta(x: T, y: T, z: T) -> tuple[T, T, T]:
    """Convert Cartesian coordinates to transverse radius, azimuth and pseudorapidity.

    Args:
        x, y, z: Cartesian coordinate tensors (combined element-wise).

    Returns:
        ``(r, phi, eta)`` where ``r = sqrt(x**2 + y**2)``,
        ``phi = atan2(y, x)`` and ``eta = -ln(tan(theta / 2))`` with the polar
        angle ``theta = atan2(r, z)``.
    """
    radius = (x.square() + y.square()).sqrt()
    azimuth = torch.atan2(y, x)
    polar = torch.atan2(radius, z)
    # Halving via 0.5 * is exact, so this matches theta / 2.0 bit-for-bit.
    pseudorapidity = torch.tan(0.5 * polar).log().neg()
    return radius, azimuth, pseudorapidity

In [30]:
class LSDatasetConverter:
    """Convert events of an LST ntuple into PyTorch Geometric ``Data`` graphs.

    Node features come from the ``columns.node_features`` branches; edges are
    defined by the two ``columns.edge_indices`` branches. ``ingress`` writes one
    ``NNNN.pt`` file per event.
    """

    def __init__(self, columns: "IngressColumnConfig"):
        self._cc = columns

    @staticmethod
    def _column(batch, name: str, dtype: torch.dtype) -> torch.Tensor:
        """Return field ``name`` of a single-event ``batch`` as a tensor of ``dtype``."""
        return torch.tensor(batch[name].to_list(), dtype=dtype)

    @staticmethod
    def _replace_inf(feature: torch.Tensor) -> torch.Tensor:
        """Replace infinities in ``feature`` in place and return it.

        ``+inf`` becomes the largest finite value and ``-inf`` the smallest
        finite value. NaNs are left untouched and ignored when computing those
        extremes. If there is no infinity, or no finite value to clamp to,
        ``feature`` is returned unchanged (the previous version raised on an
        empty/all-inf tensor and mapped ``-inf`` to the maximum).
        """
        pos_inf = torch.isposinf(feature)
        neg_inf = torch.isneginf(feature)
        if not torch.any(pos_inf | neg_inf):
            return feature
        finite = feature[torch.isfinite(feature)]
        if finite.numel() == 0:
            return feature
        feature[pos_inf] = finite.max()
        feature[neg_inf] = finite.min()
        return feature

    @classmethod
    def _stack_columns(cls, batch, names, *, replace_inf: bool = False) -> torch.Tensor:
        """Stack float fields ``names`` of ``batch`` as columns, in the given order.

        Args:
            batch: Single-event record indexable by field name.
            names: Field names; each becomes one column of the result.
            replace_inf: If True, clamp infinities per column via ``_replace_inf``.
        """
        columns = [cls._column(batch, name, torch.float) for name in names]
        if replace_inf:
            columns = [cls._replace_inf(col) for col in columns]
        # stack along dim=1 == stack + transpose(0, 1), but yields a contiguous tensor.
        return torch.stack(columns, dim=1)

    def convert_batch(
        self,
        batch,
    ):
        """Convert a single event into a ``Data`` graph.

        Args:
            batch: One event's record, indexable by branch name, where each
                field provides ``to_list()`` (as awkward arrays do).

        Returns:
            ``Data`` with
              - ``x``: node features, one column per ``node_features`` branch;
              - ``edge_index``: ``(2, n_edges)`` long tensor of node indices;
              - ``edge_attr``: edge features, one column per ``edge_features``
                branch, with infinities clamped to the finite range;
              - ``y``: True where both edge endpoints share a non-negative
                particle id;
              - ``particle_id``, ``pt``, ``eta``: the configured truth branches;
              - ``sector`` (all 0) and ``reconstructable`` (all True):
                placeholders shaped like ``particle_id``.

        NOTE(review): ``columns.edge_truth_label`` is not used here; truth is
        derived from ``particle_id``.
        """
        cc = self._cc

        # Row 0 / row 1 hold the index of the first / second node of each edge.
        edge_index = torch.stack(
            [self._column(batch, name, torch.long) for name in cc.edge_indices]
        )
        edge_attr = self._stack_columns(batch, cc.edge_features, replace_inf=True)
        node_attr = self._stack_columns(batch, cc.node_features)
        particle_id = self._column(batch, cc.particle_id, torch.long)

        # An edge is true iff both endpoints belong to the same, valid (>= 0) particle.
        src_pid = particle_id[edge_index[0]]
        dst_pid = particle_id[edge_index[1]]
        truth = (src_pid == dst_pid) & (src_pid >= 0)

        return Data(
            x=node_attr,
            y=truth,
            edge_index=edge_index,
            edge_attr=edge_attr,
            particle_id=particle_id,
            pt=self._column(batch, cc.truth_pt, torch.float),
            eta=self._column(batch, cc.truth_eta, torch.float),
            # Placeholders, not derived from the input.
            sector=torch.zeros_like(particle_id),
            reconstructable=torch.ones_like(particle_id, dtype=torch.bool),
        )

    def ingress(
        self,
        *,
        input_file: Path,
        out_dir: Path,
        tree_name: str = "tree",
        branch_filter: str = "/(MD|LS|sim)_*/",
        redo=True,
    ):
        """Convert every entry of ``tree_name`` in ``input_file`` to ``out_dir/NNNN.pt``.

        Args:
            input_file: ROOT file to read.
            out_dir: Output directory; created if missing.
            tree_name: Name of the TTree inside ``input_file``.
            branch_filter: ``filter_name`` passed to uproot's ``iterate`` to
                select which branches are read.
            redo: If False, entries whose output file already exists are not
                re-converted (they are still read from disk).
        """
        out_dir.mkdir(parents=True, exist_ok=True)
        # Open the file itself (not "file:tree") so the handle is closed when done
        # and a ':' in the path is not misparsed as an object path.
        with uproot.open(input_file) as root_file:
            tree = root_file[tree_name]
            batches = tree.iterate(step_size=1, filter_name=branch_filter)
            # len(tree) counts branches (a TTree is a mapping of branches), not
            # entries; with step_size=1 there is exactly one batch per entry.
            for i, batch in enumerate(tqdm(batches, total=tree.num_entries)):
                out_path = out_dir / f"{i:04d}.pt"
                if not redo and out_path.is_file():
                    # not ideal, because we're still reading the file which
                    # amounts to most of the time
                    continue
                event = batch[0, :]  # only one event per batch
                torch.save(self.convert_batch(event), out_path)

In [31]:
# Converter bound to the LST branch configuration defined above.
lsdc = LSDatasetConverter(columns=icc)

In [32]:
# Convert every event of the PU200 ntuple into its own processed .pt file,
# overwriting any existing outputs.
lsdc.ingress(
    redo=True,
    out_dir=Path(
        "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v1/processed"
    ),
    input_file=Path(
        "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v1/root/LSTNtuple_DNNTraining_hasT5Chi2_PU200.root"
    ),
)

  0%|          | 0/207 [00:00<?, ?it/s]

In [11]:
# Open the TTree directly via uproot's "<file>:<object>" path syntax for inspection.
# (Plain string: the former f-prefix had no placeholders.)
tree = uproot.open(
    "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v1/root/LSTNtuple_DNNTraining_hasT5Chi2_PU200.root:tree"
)

In [12]:
tree  # show the TTree summary (name and branch count)

<TTree 'tree' (207 branches) at 0x1467f590b7f0>

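## Deep dive into the truth label

Compare the stored edge labels (`data.y`) of one processed event with a label computed directly from whether an edge's two endpoints share the same particle id.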

In [14]:
# Directory written by the ingress run above: one NNNN.pt file per event.
dpath = Path("/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v1/processed")

In [15]:
# Load one processed event. glob() order is filesystem-dependent, so sort to
# always pick the same (lowest-numbered) file.
# NOTE(review): torch >= 2.6 defaults to torch.load(weights_only=True), which may
# refuse to unpickle a PyG Data object; pass weights_only=False for these
# self-produced files if that happens.
data = torch.load(sorted(dpath.glob("*.pt"))[0])

In [16]:
# Stored edge labels as booleans.
data.y.to(dtype=torch.bool)

tensor([ True,  True,  True,  ..., False, False,  True])

In [17]:
# "Real" edge label: both endpoints of the edge carry the same particle id.
# Indexing with the (2, n_edges) edge_index gathers both endpoints at once.
# NOTE(review): unlike convert_batch, this does not require the id to be >= 0,
# so edges between nodes sharing a negative id also count as real.
src_pid, dst_pid = data.particle_id[data.edge_index]
real_y = src_pid == dst_pid

In [18]:
real_y  # per-edge same-particle flags computed in the previous cell

tensor([False, False,  True,  ..., False,  True, False])

In [21]:
# Fraction of edges labelled true in data.y whose endpoints share a particle id.
real_y[data.y.to(torch.bool)].to(torch.float).mean()

tensor(0.1674)