In [13]:
import sys
from gnn_tracking.metrics.losses.metric_learning import (
    GraphConstructionHingeEmbeddingLoss,
)

sys.path.append("/home/kl5675/Documents/23/git_sync/gnn_tracking/tests")

In [59]:
from test_losses import td1, get_ml_loss, generate_test_data, MockData
import numpy as np

In [68]:
from gnn_tracking.preprocessing.point_cloud_builder import get_truth_edge_index

def generate_test_data(
    n_nodes=1000, n_particles=250, n_x_features=3, rng=None
) -> MockData:
    if rng is None:
        rng = np.random.default_rng()

    # no noise
    pid = torch.from_numpy(rng.choice(np.arange(n_particles), size=n_nodes))
    pid_unique = torch.unique(pid)
    # no low pt
    pt_pid = 1 + torch.from_numpy(2 * rng.random(len(pid_unique)))
    pt = pt_pid[pid]
    # no low eta
    eta_pid = 0.1 * torch.from_numpy(8 * (rng.random(len(pid_unique)) - 0.5))
    eta = eta_pid[pid]
    # no non-reco
    reco_pid = torch.from_numpy(rng.choice([1.0], size=len(pid_unique)))
    reco = reco_pid[pid]

    return MockData(
        beta=torch.from_numpy(rng.random(n_nodes)),
        x=torch.from_numpy(rng.random((n_nodes, n_x_features))),
        particle_id=pid,
        pred=torch.from_numpy(rng.choice([0.0, 1.0], size=(n_nodes, 1))),
        truth=torch.from_numpy(rng.choice([0.0, 1.0], size=(n_nodes, 1))),
        pt=pt,
        eta=eta,
        reconstructable=reco,
        batch=torch.zeros_like(reco),
        true_edge_index=torch.from_numpy(get_truth_edge_index(pid.numpy())),
    )


td1 = generate_test_data(50, n_particles=3, rng=np.random.default_rng(seed=0))

In [69]:
td1 = generate_test_data(50, n_particles=3, rng=np.random.default_rng(seed=0))

In [70]:
# td1.reconstructable = True
# all eta is already < 0
# td1.particle_id[td1.particle_id == 0] = 1
# assert td1.particle_id.min() > 0

In [71]:
get_ml_loss(GraphConstructionHingeEmbeddingLoss(rep_normalization="n_att_edges", rep_oi_only=False), td1)

att_edges.shape=torch.Size([2, 227])
rep_edges.shape=torch.Size([2, 1430])
norm_att=227.000000001
v_att*norm_att=tensor(151.7075, dtype=torch.float64)


{'attractive': 0.6683151300942156, 'repulsive': 2.247088209688136}

In [72]:
get_ml_loss(OldGraphConstructionHingeEmbeddingLoss(), td1)

true_edge_index.shape=torch.Size([2, 227])
true_edge_mask.sum()=tensor(227)
true_edge.sum()=tensor(425)
normalization=tensor(425.)
dists[~true_edge].shape=torch.Size([1720])
v_att_sum=tensor(271.3724, dtype=torch.float64)


{'attr': 0.6385233145469406, 'rep': 1.4239822355706957}

In [28]:
def _hinge_loss_components(
    *,
    x: T,
    att_edges: T,
    rep_edges: T,
    r_emb_hinge: float,
    p_attr: float,
    p_rep: float,
    n_hits_oi: int,
    normalization: str,
) -> tuple[T, T]:
    eps = 1e-9

    print(f"{att_edges.shape=}")
    print(f"{rep_edges.shape=}")
    dists_att = norm(x[att_edges[0]] - x[att_edges[1]], dim=-1)
    norm_att = att_edges.shape[1] + eps
    print(f"{norm_att=}")
    v_att = torch.sum(torch.pow(dists_att, p_attr)) / norm_att
    print(f"{v_att*norm_att=}")

    dists_rep = norm(x[rep_edges[0]] - x[rep_edges[1]], dim=-1)
    # There is no "good" way to normalize this: The naive way would be
    # to normalize to the number of repulsive edges, but this number
    # gets smaller and smaller as the training progresses, making the objective
    # increasingly harder.
    # The maximal number of edges that can be in the radius graph is proportional
    # to the number of hits of interest, so we normalize by this number.
    if normalization == "n_rep_edges":
        norm_rep = rep_edges.shape[1] + eps
    elif normalization == "n_hits_oi":
        norm_rep = n_hits_oi + eps
    elif normalization == "n_att_edges":
        norm_rep = att_edges.shape[1] + eps
    else:
        msg = f"Normalization {normalization} not recognized."
        raise ValueError(msg)

    # Note: Relu necessary for p < 1
    v_rep = (
        torch.sum(torch.nn.functional.relu(r_emb_hinge - torch.pow(dists_rep, p_rep)))
        / norm_rep
    )

    return v_att, v_rep


class GraphConstructionHingeEmbeddingLoss(MultiLossFct, HyperparametersMixin):
    # noinspection PyUnusedLocal
    def __init__(
        self,
        *,
        lw_repulsive: float = 1.0,
        r_emb: float = 1.0,
        max_num_neighbors: int = 256,
        pt_thld: float = 0.9,
        max_eta: float = 4.0,
        p_attr: float = 1.0,
        p_rep: float = 1.0,
        rep_normalization: str = "n_hits_oi",
        rep_oi_only: bool = True,
    ):
        """Loss for graph construction using metric learning.

        Args:
            lw_repulsive: Loss weight for repulsive part of potential loss
            r_emb: Radius for edge construction
            max_num_neighbors: Maximum number of neighbors in radius graph building.
                See https://github.com/rusty1s/pytorch_cluster#radius-graph
            pt_thld: pt threshold for particles of interest
            max_eta: maximum eta for particles of interest
            p_attr: Power for the attraction term (default 1: linear loss)
            p_rep: Power for the repulsion term (default 1: linear loss)
            normalization: Normalization for the repulsive term. Can be either
                "n_rep_edges" (normalizes by the number of repulsive edges < r_emb) or
                "n_hits_oi" (normalizes by the number of hits of interest) or
                "n_att_edges" (normalizes by the number of attractive edges of interest)
            rep_oi_only: Only consider repulsion between hits if at least one
                of the hits is of interest
        """
        super().__init__()
        self.save_hyperparameters()

    def _get_edges(
        self, *, x: T, batch: T, true_edge_index: T, mask: T, particle_id: T
    ) -> tuple[T, T]:
        """Returns edge index for graph"""
        near_edges = radius_graph(
            x,
            r=self.hparams.r_emb,
            batch=batch,
            loop=False,
            max_num_neighbors=self.hparams.max_num_neighbors,
        )
        # Every edge has to start at a particle of interest, so no special
        # case with noise
        if self.hparams.rep_oi_only:
            rep_edges = near_edges[:, mask[near_edges[0]]]
        else:
            rep_edges = near_edges
        rep_edges = rep_edges[:, particle_id[rep_edges[0]] != particle_id[rep_edges[1]]]
        att_edges = true_edge_index[:, mask[true_edge_index[0]]]
        return att_edges, rep_edges

    # noinspection PyUnusedLocal
    def forward(
        self,
        *,
        x: T,
        particle_id: T,
        batch: T,
        true_edge_index: T,
        pt: T,
        eta: T,
        reconstructable: T,
        **kwargs,
    ) -> MultiLossFctReturn:
        if true_edge_index is None:
            msg = (
                "True_edge_index must be given and not be None. Are you trying to use "
                "this loss for OC training? In this case, double check that you are "
                "properly passing on the true edges."
            )
            raise ValueError(msg)
        mask = get_good_node_mask_tensors(
            pt=pt,
            particle_id=particle_id,
            reconstructable=reconstructable,
            eta=eta,
            pt_thld=self.hparams.pt_thld,
            max_eta=self.hparams.max_eta,
        )
        # oi = of interest
        n_hits_oi = mask.sum()
        att_edges, rep_edges = self._get_edges(
            x=x,
            batch=batch,
            true_edge_index=true_edge_index,
            mask=mask,
            particle_id=particle_id,
        )
        attr, rep = _hinge_loss_components(
            x=x,
            att_edges=att_edges,
            rep_edges=rep_edges,
            r_emb_hinge=self.hparams.r_emb,
            p_attr=self.hparams.p_attr,
            p_rep=self.hparams.p_rep,
            n_hits_oi=n_hits_oi,
            normalization=self.hparams.rep_normalization,
        )
        losses = {
            "attractive": attr,
            "repulsive": rep,
        }
        weights: dict[str, float] = {
            "attractive": 1.0,
            "repulsive": self.hparams.lw_repulsive,
        }
        extra = {
            "n_hits_oi": n_hits_oi,
            "n_edges_att": att_edges.shape[1],
            "n_edges_rep": rep_edges.shape[1],
        }
        return MultiLossFctReturn(
            loss_dct=losses,
            weight_dct=weights,
            extra_metrics=extra,
        )

In [33]:
import torch
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from torch import Tensor as T
from torch.linalg import norm
from torch_cluster import radius_graph

from gnn_tracking.metrics.losses import MultiLossFct, MultiLossFctReturn
from gnn_tracking.utils.graph_masks import get_good_node_mask_tensors
from torch.nn.functional import relu
from torch import nn

def _old_hinge_loss_components(
    *,
    x: T,
    edge_index: T,
    particle_id: T,
    pt: T,
    r_emb_hinge: float,
    pt_thld: float,
    p_attr: float,
    p_rep: float,
) -> tuple[T, T]:
    true_edge = (particle_id[edge_index[0]] == particle_id[edge_index[1]]) & (
        particle_id[edge_index[0]] > 0
    )
    true_high_pt_edge = true_edge & (pt[edge_index[0]] > pt_thld)
    dists = norm(x[edge_index[0]] - x[edge_index[1]], dim=-1)
    normalization = true_high_pt_edge.sum() + 1e-8
    print(f"{true_edge.sum()=}")
    print(f"{normalization=}")
    print(f"{dists[~true_edge].shape=}")
    v_att_sum = torch.sum(
        torch.pow(dists[true_high_pt_edge], p_attr)
    )
    print(f"{v_att_sum=}")
    return torch.sum(
        torch.pow(dists[true_high_pt_edge], p_attr)
    ) / normalization, torch.sum(
        relu(r_emb_hinge - torch.pow(dists[~true_edge], p_rep)) / normalization
    )


class OldGraphConstructionHingeEmbeddingLoss(nn.Module, HyperparametersMixin):
    # noinspection PyUnusedLocal
    def __init__(
        self,
        *,
        r_emb=1,
        max_num_neighbors: int = 256,
        attr_pt_thld: float = 0.9,
        p_attr: float = 1,
        p_rep: float = 1,
    ):
        """Loss for graph construction using metric learning.

        Args:
            r_emb: Radius for edge construction
            max_num_neighbors: Maximum number of neighbors in radius graph building.
                See https://github.com/rusty1s/pytorch_cluster#radius-graph
            p_attr: Power for the attraction term (default 1: linear loss)
            p_rep: Power for the repulsion term (default 1: linear loss)
        """
        super().__init__()
        self.save_hyperparameters()

    def _build_graph(self, x: T, batch: T, true_edge_index: T, pt: T) -> T:
        true_edge_mask = pt[true_edge_index[0]] > self.hparams.attr_pt_thld
        near_edges = radius_graph(
            x,
            r=self.hparams.r_emb,
            batch=batch,
            loop=False,
            max_num_neighbors=self.hparams.max_num_neighbors,
        )
        print(f"{true_edge_index.shape=}")
        print(f"{true_edge_mask.sum()=}")
        # return torch.unique(
        #     torch.cat([true_edge_index[:, true_edge_mask], near_edges], dim=-1), dim=-1
        # )
        return torch.unique(torch.cat([true_edge_index[:, true_edge_mask], near_edges], dim=-1), dim=-1)
    # noinspection PyUnusedLocal
    def forward(
        self, *, x: T, particle_id: T, batch: T, true_edge_index: T, pt: T, **kwargs
    ) -> dict[str, T]:
        edge_index = self._build_graph(
            x=x, batch=batch, true_edge_index=true_edge_index, pt=pt
        )
        attr, rep = _old_hinge_loss_components(
            x=x,
            edge_index=edge_index,
            particle_id=particle_id,
            r_emb_hinge=self.hparams.r_emb,
            pt=pt,
            pt_thld=self.hparams.attr_pt_thld,
            p_attr=self.hparams.p_attr,
            p_rep=self.hparams.p_rep,
        )
        losses = {
            "attr": attr,
            "rep": rep,
        }
        return MultiLossFctReturn(
            loss_dct=losses,
            weight_dct={"attr": 1, "rep": 1},
            extra_metrics={},
        )

In [21]:
true_edge_mask = td1.pt[td1.true_edge_index[0]] > 0.9

In [107]:
near_edges = radius_graph(
    td1.x,
    r=1,
    batch=td1.batch,
    loop=False,
    max_num_neighbors=256,
)

In [108]:
near_edges.shape

torch.Size([2, 2116])

In [109]:
td1.true_edge_index.shape

torch.Size([2, 227])

In [115]:
torch.sort(torch.tensor([[1, 2, 3], [0, 0, 0]]), dim=0).values

tensor([[0, 0, 0],
        [1, 2, 3]])

In [110]:
((td1.particle_id[near_edges[0]] == td1.particle_id[near_edges[1]]) & (td1.particle_id[near_edges[0]] > 0)).sum()

tensor(396)

In [48]:
td1.true_edge_index.shape

torch.Size([2, 227])

In [117]:
ei = torch.unique(torch.sort(torch.cat([td1.true_edge_index[:, true_edge_mask], near_edges], dim=-1), dim=0).values, dim=-1)

In [120]:
true_edge = (td1.particle_id[ei[0]] == td1.particle_id[ei[1]]) & (td1.particle_id[ei[0]] > 0)

In [121]:
true_edge.sum()

tensor(227)

In [40]:
td1.true_edge_index.shape

torch.Size([2, 227])

In [47]:
torch.unique(torch.tensor([
    [0, 2],
    [1, 1]
]), dim=-1)

tensor([[0, 2],
        [1, 1]])

In [102]:
random_pids = np.random.randint(0, 3, 1000)
get_truth_edge_index(random_pids).shape

(2, 103891)

In [106]:
total = 0
for i in np.unique(random_pids):
    if i == 0:
        continue
    count = (random_pids == i).sum()
    total += count * (count-1) // 2
total 

103891

In [None]:
def get_truth_edge_index(pids: A) -> A:
    """Get edge index for all edges, connecting hits of the same `particle_id`.
    To save space, only edges in one direction are returned.
    """
    upids = np.unique(pids[pids > 0])
    mask: A = pids.reshape(1, -1) == upids.reshape(-1, 1)  # type: ignore
    edges = []
    for i_particle in range(mask.shape[0]):
        indices = np.nonzero(mask[i_particle])[0]
        if len(indices) < 2:
            continue
        edges += list(itertools.combinations(indices, 2))
    return np.array(edges).T