In [46]:
import sys
from gnn_tracking.metrics.losses import _first_occurrences, _square_distances
import torch

from torch import Tensor as T
from torch_cluster import radius_graph

from gnn_tracking.utils.log import logger


import os

# Make the repository checkout importable so that the `tests` package resolves.
# The location can be overridden through the GNN_TRACKING_ROOT environment
# variable; the previously hard-coded path remains the default.
_repo_root = os.environ.get(
    "GNN_TRACKING_ROOT", "/home/fuchur/Documents/23/git_sync/gnn_tracking"
)
# Re-running this cell must not keep appending duplicate entries to sys.path.
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

from tests.test_losses import generate_test_data

# Test data providing (at least) the `beta`, `x` and `particle_id` attributes
# that are read in the cells below.
td1 = generate_test_data()

In [47]:

# Inputs of the loss computation, taken from the generated test data.
beta = td1.beta
x = td1.x
particle_id = td1.particle_id

# Offset added to arctanh(beta)**2 when computing q.
q_min = 0.01
# Boolean mask with one entry per element of beta; all True here, so no
# attraction edge is filtered out.
mask = torch.ones_like(td1.beta, dtype=torch.bool)

# Parameters forwarded to radius_graph; the radius is also the reference
# distance of the repulsive term.
radius_threshold = 1
max_num_neighbors = 500_000


In [48]:
# -- 1. Determine indices of condensation points (CPs) and q --
# Hit indices ordered by decreasing beta, and the particle IDs in that order
_sorted_indices = beta.argsort(descending=True)
_pids_sorted = particle_id[_sorted_indices]
# One candidate hit index per particle ID (presumably the highest-beta hit, given
# the descending sort -- depends on _first_occurrences; TODO confirm)
_alphas = _sorted_indices[_first_occurrences(_pids_sorted)]
# Index of condensation points in node array: candidates with particle_id > 0 only
alphas = _alphas[particle_id[_alphas].gt(0)]
assert alphas.shape[0] > 0, "No particles found, cannot evaluate loss"
q = q_min + torch.arctanh(beta).pow(2)
assert not q.isnan().any(), "q contains NaNs"

# -- 2. Edges for repulsion loss --
_radius_edges = radius_graph(
    x=x,
    r=radius_threshold,
    loop=False,
    max_num_neighbors=max_num_neighbors,
)
# Two per-edge boolean masks, combined later: edges whose first endpoint is a CP,
# and edges whose two endpoints carry different particle IDs
_to_cp = torch.isin(_radius_edges[0], alphas)
_is_repulsive = particle_id[_radius_edges[0]].ne(particle_id[_radius_edges[1]])

In [49]:
# Sanity check: one entry per radius-graph edge expected
_is_repulsive.size()

torch.Size([909002])

In [50]:
# Sanity check: must have the same length as _is_repulsive to be combined
_to_cp.size()

torch.Size([909002])

In [51]:
# Keep only the edges that satisfy both per-edge conditions
repulsion_edges = _radius_edges[:, torch.logical_and(_is_repulsive, _to_cp)]

In [52]:

# -- 3. Edges for attractive loss --
# Boolean 1D array (n_nodes): True for CPs, False otherwise
alpha_hits_filter = torch.zeros(
    len(particle_id), dtype=bool, device=x.device
).scatter_(0, alphas, 1)
# indices of all non-CPs
non_alpha_indices = torch.arange(len(particle_id), device=x.device)[
    ~alpha_hits_filter
]

# for each non-CP hit, the index of the corresponding CP
# NOTE(review): searchsorted requires particle_id[_alphas] to be sorted in
# ascending order; presumably guaranteed by _first_occurrences -- confirm.
# NOTE(review): the lookup goes through _alphas (all particle IDs), whereas
# alpha_hits_filter is built from alphas (particle_id > 0 only).
alpha_indices = _alphas[
    torch.searchsorted(particle_id[_alphas], particle_id[non_alpha_indices])
]

# Insert alpha indices into their respective positions to form attraction edges.
# Both rows start out as the identity (every hit paired with itself); row 1 is
# then overwritten with the CP index for all non-CP hits.
unmasked_attraction_edges = (
    torch.arange(len(particle_id), device=x.device).unsqueeze(0).repeat(2, 1)
)
unmasked_attraction_edges[1, ~alpha_hits_filter] = alpha_indices

# Apply mask to attraction edges
attraction_edges = unmasked_attraction_edges[:, mask]

# -- 4. Calculate loss --
# Repulsion uses (radius_threshold - sqrt(.)) of the helper's result, attraction
# uses the helper's result directly (presumably squared distances between the
# edge endpoints, judging by the helper's name -- confirm).
repulsion_distances = radius_threshold - torch.sqrt(
    _square_distances(repulsion_edges, x)
)
attraction_distances = _square_distances(attraction_edges, x)

# Weight every edge by the q values of both of its endpoints
va = attraction_distances * q[attraction_edges[0]] * q[attraction_edges[1]]
vr = repulsion_distances * q[repulsion_edges[0]] * q[repulsion_edges[1]]

if torch.isnan(vr).any():
    logger.warning("Repulsive loss is NaN")
    # Fall back to a zero contribution. zeros_like keeps the dtype, device and
    # shape of vr; a literal torch.tensor([[0.0]]) would instead be a float32
    # CPU tensor of shape (1, 1) regardless of where vr lives.
    vr = torch.zeros_like(vr)

# Attractive term is normalised by the number of True mask entries, repulsive
# term by the number of rows of x.
a = {
    "attractive": (1 / mask.sum()) * torch.sum(va),
    "repulsive": (1 / x.size()[0]) * torch.sum(vr),
}
# Combined value with the repulsive term weighted by a factor of 10
a["attractive"] + 10 * a["repulsive"]

tensor(1393.8707, dtype=torch.float64)

In [53]:
# Cross-check: evaluate the loss through the helper from the test suite on the
# same test data, for comparison with the step-by-step value computed above.
# NOTE(review): the recorded outputs of the two cells differ slightly
# (1393.8707 above vs. 1393.8616 here) -- the cause is not visible in this
# notebook.
from tests.test_losses import get_condensation_loss

get_condensation_loss(td1)

tensor(1393.8616, dtype=torch.float64)