## _Graph Intersection_

Our goal is to build a _**labelled dataset**_ in the form of inputs (`edge_index`) and targets (`y`) for edge classification: _`true_edges`, `input_edges` $\rightarrow$ `edge_index`, `y`_. We have

- _`true_edges` is the truth graph from `layerwise_true_edges` or `modulewise_true_edges()`_
- _`input_edges` is the input graph from the Heuristic Method_

and we would like to build the _**labelled dataset** [`edge_index`, `y`]_ using _`graph_intersection()`_.
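
As a minimal sketch of the idea (not the notebook's `graph_intersection()` itself), the labelling can be pictured with plain Python sets: every edge of the input graph gets the label 1 if it also appears in the truth graph and 0 otherwise. The function name `label_edges` and the toy edge lists are hypothetical and only serve as an illustration.

```python
def label_edges(input_edges, true_edges):
    """Label each input edge 1 if it is also a true edge, else 0.

    Both arguments are 2 x N nested lists: row 0 holds senders, row 1 receivers.
    """
    truth = set(zip(true_edges[0], true_edges[1]))
    # de-duplicate and sort the input edges by (sender, receiver)
    pairs = sorted(set(zip(input_edges[0], input_edges[1])))
    edge_index = [[s for s, _ in pairs], [r for _, r in pairs]]
    y = [int(pair in truth) for pair in pairs]
    return edge_index, y


# toy graphs (hypothetical hit indices)
toy_true = [[0, 1, 2], [1, 2, 3]]
toy_input = [[0, 0, 1], [1, 2, 2]]

toy_edge_index, toy_y = label_edges(toy_input, toy_true)
```

In this toy case the true edge `(2, 3)` is not part of the input graph, so it receives no label at all: only edges present in the input graph end up in `edge_index`.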



In [None]:
import glob, os, sys, yaml

In [None]:
import numpy as np
import scipy as sp
import scipy.sparse  # make sure sp.sparse is loaded for graph_intersection()
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
import pprint
import seaborn as sns
import trackml.dataset

In [None]:
import torch
from torch_geometric.data import Data
import itertools

In [None]:
# append parent dir
sys.path.append("..")

In [None]:
# get cuda device
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# local imports
from src import SttCSVDataReader, SttTorchDataReader
from src import detector_layout
from src import Build_Event, Build_Event_Viz, Visualize_Edges
from src.math_utils import polar_to_cartesian

### _Input Data_

In [None]:
# input data
input_dir = "../data_all"

In [None]:
# Find All Input Data Files (hits.csv, cells.csv, particles.csv, truth.csv)
all_files = os.listdir(input_dir)

# Extract File Prefixes (use e.g. xxx-hits.csv)
suffix = "-hits.csv"
file_prefixes = sorted(
    os.path.join(input_dir, f.replace(suffix, ""))
    for f in all_files
    if f.endswith(suffix)
)

print("Number of Files: ", len(file_prefixes))

In [None]:
# file_prefixes[:10]

In [None]:
# load an event
# hits, tubes, particles, truth = trackml.dataset.load_event(file_prefixes[0])

In [None]:
# hits.head()
# tubes.head()
# particles.head()
# truth.head()

### _Visualize Event_

- _`Build_Event()` is the same as `select_hits()` in `processing/utils/event_utils.py`_

In [None]:
# select event
event_id = 95191

In [None]:
# composing an event with Build_Event() is exactly the same as select_hits()
# event = Build_Event(input_dir, event_id, noise=False, skewed=False, selection=False)

In [None]:
# visualize event
# Build_Event_Viz(event, figsize=(10,10), fig_type="pdf", save_fig=False)

## _Graph Intersection_

- _`true_edges`, `input_edges` $\rightarrow$ `edge_index`, `y`_

In [None]:
from LightningModules.Processing.utils.event_utils import select_hits
from LightningModules.Processing.utils.event_utils import get_layerwise_edges
from LightningModules.Processing.utils.graph_utils import get_input_edges

In [None]:
# get event prefix using event_id
event_prefix = file_prefixes[event_id]

In [None]:
# select hits
kwargs = {"selection": False}
event = select_hits(event_file=event_prefix, noise=False, skewed=False, **kwargs)

### _(A) - True Edges (Layerwise)_

In [None]:
# get true edges with new hits (changed)
true_edges, hits = get_layerwise_edges(event)

In [None]:
# split into senders and receivers
senders, receivers = true_edges

In [None]:
senders.shape, receivers.shape

In [None]:
# visualize nodes and edges
Visualize_Edges(hits, true_edges, figsize=(10, 10), fig_type="pdf", save_fig=False)

### _(B) - Input Edges (Layerwise)_

In [None]:
# get input edges
input_edges = get_input_edges(hits, filtering=True)

In [None]:
# split into senders and receivers
senders, receivers = input_edges

In [None]:
senders.shape, receivers.shape

In [None]:
# visualize nodes and edges
# Visualize_Edges (hits, input_edges, figsize=(10,10), fig_type="pdf", save_fig=False)

### _(C) - Labelled Dataset_

- _Use `true_edges` and `input_edges` to build `edge_index` and `y`. Note that the labelled dataset is `[inputs, targets]` $\rightarrow$ `[edge_index, y]`_

In [None]:
def graph_intersection(pred_graph, truth_graph):
    """Get truth information about edge_index (function is from both Embedding/Filtering)"""

    array_size = max(pred_graph.max().item(), truth_graph.max().item()) + 1

    if torch.is_tensor(pred_graph):
        l1 = pred_graph.cpu().numpy()
    else:
        l1 = pred_graph

    if torch.is_tensor(truth_graph):
        l2 = truth_graph.cpu().numpy()
    else:
        l2 = truth_graph

    e_1 = sp.sparse.coo_matrix(
        (np.ones(l1.shape[1]), l1), shape=(array_size, array_size)
    ).tocsr()

    e_2 = sp.sparse.coo_matrix(
        (np.ones(l2.shape[1]), l2), shape=(array_size, array_size)
    ).tocsr()

    del l1
    del l2

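    # positive where an edge is in both graphs, negative where it is only in pred_graph;
    # edges that are only in truth_graph get no stored entry and are dropped
    e_intersection = (e_1.multiply(e_2) - ((e_1 - e_2) > 0)).tocoo()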

    del e_1
    del e_2

    new_pred_graph = (
        torch.from_numpy(np.vstack([e_intersection.row, e_intersection.col]))
        .long()
        .to(device)
    )

    y = torch.from_numpy(e_intersection.data > 0).to(device)

    del e_intersection

    return new_pred_graph, y
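
The sign arithmetic inside `graph_intersection()` can be easier to follow on a tiny dense example. The sketch below is a dense NumPy analogue, not the sparse implementation above; the toy arrays and the names `a_pred`, `a_truth` and `signed` are hypothetical.

```python
import numpy as np

# toy graphs on 4 nodes (hypothetical indices)
pred = np.array([[0, 0, 1], [1, 2, 2]])   # input edges
truth = np.array([[0, 1, 2], [1, 2, 3]])  # true edges

n = int(max(pred.max(), truth.max())) + 1
a_pred = np.zeros((n, n))
a_truth = np.zeros((n, n))
a_pred[pred[0], pred[1]] = 1
a_truth[truth[0], truth[1]] = 1

# +1 where an edge is in both graphs, -1 where it is only in the input graph,
# 0 everywhere else (including edges that are only in the truth graph)
signed = a_pred * a_truth - ((a_pred - a_truth) > 0)

# keep the non-zero entries: these are exactly the input edges
rows, cols = np.nonzero(signed)
toy_edge_index = np.vstack([rows, cols])
toy_y = signed[rows, cols] > 0
```

Here `toy_edge_index` contains the three input edges in row-major order, and `toy_y` marks which of them are also true edges. The true edge `(2, 3)` maps to 0 in `signed` and therefore never shows up in the output.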

In [None]:
# returns sorted input_graph
edge_index, y = graph_intersection(input_edges, true_edges)

- check the shape of tensors

In [None]:
edge_index.shape, y.shape

- extract true and false edges

In [None]:
# get true edges
true_edge_mask = y.bool()  # convert to boolean mask
true_edges = edge_index[:, true_edge_mask]  # filter true edges

In [None]:
true_edges.shape

In [None]:
# get false edges
false_edge_mask = ~(y.bool())  # convert to boolean mask
false_edges = edge_index[:, false_edge_mask]  # filter false edges

In [None]:
false_edges.shape

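- The original number of true edges was 465; after graph intersection, only 462 are extracted.
- Why are some edges missing?
- A likely explanation: `graph_intersection()` returns only edges that occur in `input_edges`, so true edges that the heuristic method did not build (or duplicate entries in `true_edges`) cannot appear in the result.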
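
One way to investigate the missing edges is to list the true edges that never appear in the input graph. The helper below is a hypothetical sketch (the name `missing_true_edges` is not part of the repository). It also reports edges that appear only in the reversed direction, which would point to an ordering mismatch rather than an edge that is really absent. Note that `true_edges` is reassigned after the intersection in the cells above, so the comparison needs the original array returned by `get_layerwise_edges()`.

```python
import numpy as np


def missing_true_edges(true_edges, input_edges):
    """Return true edges absent from the input graph, and those present only reversed.

    Both arguments are 2 x N integer arrays (row 0: senders, row 1: receivers).
    """
    true_edges = np.asarray(true_edges)
    input_edges = np.asarray(input_edges)
    inputs = set(map(tuple, input_edges.T.tolist()))

    missing, reversed_only = [], []
    for s, r in true_edges.T.tolist():
        if (s, r) in inputs:
            continue  # edge is present, so it can be labelled
        if (r, s) in inputs:
            reversed_only.append((s, r))
        else:
            missing.append((s, r))
    return missing, reversed_only


# toy check (hypothetical indices)
toy_true = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
toy_input = np.array([[0, 0, 1, 4], [1, 2, 2, 3]])
missing, reversed_only = missing_true_edges(toy_true, toy_input)
```

For torch tensors, pass `tensor.cpu().numpy()` to the helper. If `len(missing)` matches the gap between the two edge counts, the missing edges are simply ones the input graph never contained.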