# Fix radius graph loss function

* **Description**: Backpropagating from Jian's loss function reslts in NaNs in weights. Why?
* **Status**:  Active
* **Preceeded by**:
* **Succeeded by**:
* **See also**: 


In [1]:
import sys
from gnn_tracking.metrics.losses import _first_occurrences, _square_distances
import torch

from torch import Tensor as T
from torch import nn
from torch_cluster import radius_graph

from gnn_tracking.utils.log import logger

from pathlib import Path



In [2]:
repo_path = Path("/home/kl5675/Documents/23/git_sync/gnn_tracking/tests")
assert repo_path.is_dir()
sys.path.append(str(repo_path))


In [3]:

from test_losses import generate_test_data

td1 = generate_test_data()

In [4]:
x_orig = torch.rand_like(td1.x, dtype=torch.float)
x_orig.shape, x_orig.dtype

(torch.Size([1000, 3]), torch.float32)

In [5]:
from gnn_tracking.models.mlp import MLP

toy_module = MLP(
    input_size=3,
    output_size=3,
    hidden_dim=3,
    L=3,
    bias=False
)
x_orig.shape

torch.Size([1000, 3])

In [6]:
optimizer = torch.optim.Adam(toy_module.parameters(), lr=0.00001)


In [7]:
beta=td1.beta
particle_id=td1.particle_id
q_min=0.01
mask=torch.ones_like(td1.beta, dtype=bool)
radius_threshold=1
max_num_neighbors=500000

In [8]:
list(toy_module.parameters())

[Parameter containing:
 tensor([[-0.0611, -0.2563,  0.5765],
         [-0.2652, -0.4793,  0.1608],
         [ 0.5089, -0.5428, -0.5499]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.2008, -0.4391, -0.5297],
         [ 0.2972,  0.4742,  0.0300],
         [ 0.1287,  0.1871,  0.4061]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.0476, -0.5268,  0.4689],
         [ 0.2731, -0.3157,  0.0427],
         [-0.0703,  0.4760, -0.0117]], requires_grad=True)]

In [9]:
optimizer.zero_grad()
x = toy_module(x_orig)


In [10]:
# -- 1. Determine indices of condensation points (CPs) and q --
_sorted_indices = torch.argsort(beta, descending=True)
_pids_sorted = particle_id[_sorted_indices]
_alphas = _sorted_indices[_first_occurrences(_pids_sorted)]
# Index of condensation points in node array
alphas = _alphas[particle_id[_alphas] > 0]
assert alphas.size()[0] > 0, "No particles found, cannot evaluate loss"
q = torch.arctanh(beta) ** 2 + q_min
assert not torch.isnan(q).any(), "q contains NaNs"

# -- 2. Edges for repulsion loss --
_radius_edges = radius_graph(
    x=x, r=radius_threshold, max_num_neighbors=max_num_neighbors, loop=False
)
# Now filter out everything that doesn't include a CP or connects two hits of the
# same particle
_to_cp = torch.isin(_radius_edges[0], alphas)
_is_repulsive = particle_id[_radius_edges[0]] != particle_id[_radius_edges[1]]
_is_repulsive.shape

torch.Size([999000])

In [11]:
# repulsion_edges = _radius_edges[:, _is_repulsive & _to_cp]

In [12]:
def _square_distances(edges: T, positions: T) -> T:
    """Returns squared distances between two sets of points"""
    return torch.sum((positions[edges[0]] - positions[edges[1]]) ** 2, dim=-1)

In [13]:

# -- 3. Edges for attractive loss --
# 1D array (n_nodes): 1 for CPs, 0 otherwise
# alpha_hits_filter = torch.zeros(
#     len(particle_id), dtype=bool, device=x.device
# ).scatter_(0, alphas, 1)
# # indices of all non-CPs
# non_alpha_indices = torch.arange(len(particle_id), device=x.device)[
#     ~alpha_hits_filter
# ]

# for each non-CP hit, the index of the corresponding CP
# alpha_indices = _alphas[
#     torch.searchsorted(particle_id[_alphas], particle_id[non_alpha_indices])
# ]

# Insert alpha indices into their respective positions to form attraction edges
# unmasked_attraction_edges = (
    # torch.arange(len(particle_id), device=x.device).unsqueeze(0).repeat(2, 1)
# )
# unmasked_attraction_edges[1, ~alpha_hits_filter] = alpha_indices

# Apply mask to attraction edges
# attraction_edges = unmasked_attraction_edges[:, mask]

# -- 4. Calculate loss --
# repulsion_distances = radius_threshold - torch.sqrt(
#     _square_distances(_radius_edges, x)
# )


In [14]:

# attraction_distances = _square_distances(attraction_edges, x)

# va = attraction_distances * q[attraction_edges[0]] * q[attraction_edges[1]]
# vr = repulsion_distances * q[repulsion_edges[0]] * q[repulsion_edges[1]]

# assert not torch.isnan(vr).any()
# # if torch.isnan(vr).any():
#     vr = torch.tensor([[0.0]])
#     logger.warning("Repulsive loss is NaN")

# a = {
#     "attractive": (1 / mask.sum()) * torch.sum(va),
#     "repulsive": (1 / x.size()[0]) * torch.sum(vr),
# }
loss = torch.sqrt(_square_distances(_radius_edges, x)).mean()# + va.sum()
# loss = a["attractive"] + 10*a["repulsive"]

In [15]:
loss

tensor(0.0329, grad_fn=<MeanBackward0>)

In [16]:
loss.backward()

In [17]:
optimizer.step()

In [18]:
list(toy_module.parameters())

[Parameter containing:
 tensor([[-0.0611, -0.2563,  0.5765],
         [-0.2652, -0.4793,  0.1608],
         [ 0.5089, -0.5428, -0.5499]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.2008, -0.4391, -0.5296],
         [ 0.2972,  0.4742,  0.0300],
         [ 0.1287,  0.1871,  0.4061]], requires_grad=True),
 Parameter containing:
 tensor([[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]], requires_grad=True)]

## Simplified

In [271]:
td1 = generate_test_data()

In [272]:
x_orig = torch.rand_like(td1.x, dtype=torch.float)
x_orig.shape, x_orig.dtype

(torch.Size([1000, 3]), torch.float32)

In [320]:
from gnn_tracking.models.mlp import MLP

toy_module = MLP(
    input_size=3,
    output_size=3,
    hidden_dim=3,
    L=3,
    bias=False
)
x_orig.shape

torch.Size([1000, 3])

In [321]:
optimizer = torch.optim.Adam(toy_module.parameters(), lr=0.00001)


In [322]:
mask=torch.ones_like(td1.beta, dtype=bool)
radius_threshold=1
max_num_neighbors=500000

In [328]:
list(toy_module.parameters())

[Parameter containing:
 tensor([[ 0.5678, -0.4374, -0.4212],
         [ 0.4101, -0.3442,  0.3485],
         [ 0.0829, -0.5078, -0.3383]], requires_grad=True),
 Parameter containing:
 tensor([[-0.4134, -0.0649,  0.1921],
         [-0.4560, -0.1656,  0.3097],
         [-0.0355, -0.4623, -0.3310]], requires_grad=True),
 Parameter containing:
 tensor([[-0.4305, -0.1754,  0.5266],
         [ 0.1424, -0.1132,  0.0502],
         [-0.4116,  0.1427,  0.1679]], requires_grad=True)]

In [None]:
x_orig = torch.rand_like(td1.x, dtype=torch.float)
x_orig.shape, x_orig.dtype

(torch.Size([1000, 3]), torch.float32)

In [None]:
from gnn_tracking.models.mlp import MLP

toy_module = MLP(
    input_size=3,
    output_size=3,
    hidden_dim=3,
    L=3,
    bias=False
)
x_orig.shape

torch.Size([1000, 3])

In [None]:
optimizer = torch.optim.Adam(toy_module.parameters(), lr=0.00001)


In [None]:
beta=td1.beta
particle_id=td1.particle_id
q_min=0.01
mask=torch.ones_like(td1.beta, dtype=bool)
radius_threshold=1
max_num_neighbors=500000

In [None]:
list(toy_module.parameters())

[Parameter containing:
 tensor([[-0.0611, -0.2563,  0.5765],
         [-0.2652, -0.4793,  0.1608],
         [ 0.5089, -0.5428, -0.5499]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.2008, -0.4391, -0.5297],
         [ 0.2972,  0.4742,  0.0300],
         [ 0.1287,  0.1871,  0.4061]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.0476, -0.5268,  0.4689],
         [ 0.2731, -0.3157,  0.0427],
         [-0.0703,  0.4760, -0.0117]], requires_grad=True)]

In [None]:
optimizer.zero_grad()
x = toy_module(x_orig)


In [None]:
# -- 1. Determine indices of condensation points (CPs) and q --
_sorted_indices = torch.argsort(beta, descending=True)
_pids_sorted = particle_id[_sorted_indices]
_alphas = _sorted_indices[_first_occurrences(_pids_sorted)]
# Index of condensation points in node array
alphas = _alphas[particle_id[_alphas] > 0]
assert alphas.size()[0] > 0, "No particles found, cannot evaluate loss"
q = torch.arctanh(beta) ** 2 + q_min
assert not torch.isnan(q).any(), "q contains NaNs"

# -- 2. Edges for repulsion loss --
_radius_edges = radius_graph(
    x=x, r=radius_threshold, max_num_neighbors=max_num_neighbors, loop=False
)
# Now filter out everything that doesn't include a CP or connects two hits of the
# same particle
_to_cp = torch.isin(_radius_edges[0], alphas)
_is_repulsive = particle_id[_radius_edges[0]] != particle_id[_radius_edges[1]]
_is_repulsive.shape

torch.Size([999000])

In [None]:
# repulsion_edges = _radius_edges[:, _is_repulsive & _to_cp]

In [None]:
def _square_distances(edges: T, positions: T) -> T:
    """Returns squared distances between two sets of points"""
    return torch.sum((positions[edges[0]] - positions[edges[1]]) ** 2, dim=-1)

In [None]:

# -- 3. Edges for attractive loss --
# 1D array (n_nodes): 1 for CPs, 0 otherwise
# alpha_hits_filter = torch.zeros(
#     len(particle_id), dtype=bool, device=x.device
# ).scatter_(0, alphas, 1)
# # indices of all non-CPs
# non_alpha_indices = torch.arange(len(particle_id), device=x.device)[
#     ~alpha_hits_filter
# ]

# for each non-CP hit, the index of the corresponding CP
# alpha_indices = _alphas[
#     torch.searchsorted(particle_id[_alphas], particle_id[non_alpha_indices])
# ]

# Insert alpha indices into their respective positions to form attraction edges
# unmasked_attraction_edges = (
    # torch.arange(len(particle_id), device=x.device).unsqueeze(0).repeat(2, 1)
# )
# unmasked_attraction_edges[1, ~alpha_hits_filter] = alpha_indices

# Apply mask to attraction edges
# attraction_edges = unmasked_attraction_edges[:, mask]

# -- 4. Calculate loss --
# repulsion_distances = radius_threshold - torch.sqrt(
#     _square_distances(_radius_edges, x)
# )


In [None]:

# attraction_distances = _square_distances(attraction_edges, x)

# va = attraction_distances * q[attraction_edges[0]] * q[attraction_edges[1]]
# vr = repulsion_distances * q[repulsion_edges[0]] * q[repulsion_edges[1]]

# assert not torch.isnan(vr).any()
# # if torch.isnan(vr).any():
#     vr = torch.tensor([[0.0]])
#     logger.warning("Repulsive loss is NaN")

# a = {
#     "attractive": (1 / mask.sum()) * torch.sum(va),
#     "repulsive": (1 / x.size()[0]) * torch.sum(vr),
# }
loss = torch.sqrt(_square_distances(_radius_edges, x)).mean()# + va.sum()
# loss = a["attractive"] + 10*a["repulsive"]

In [None]:
loss

tensor(0.0329, grad_fn=<MeanBackward0>)

In [None]:
loss.backward()

In [None]:
optimizer.step()

In [None]:
list(toy_module.parameters())

[Parameter containing:
 tensor([[-0.0611, -0.2563,  0.5765],
         [-0.2652, -0.4793,  0.1608],
         [ 0.5089, -0.5428, -0.5499]], requires_grad=True),
 Parameter containing:
 tensor([[ 0.2008, -0.4391, -0.5296],
         [ 0.2972,  0.4742,  0.0300],
         [ 0.1287,  0.1871,  0.4061]], requires_grad=True),
 Parameter containing:
 tensor([[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]], requires_grad=True)]

In [329]:
optimizer.zero_grad()
x = toy_module(x_orig)

In [325]:
_radius_edges = radius_graph(
    x=x, r=radius_threshold, max_num_neighbors=max_num_neighbors, loop=False
)

In [326]:
(_radius_edges[0] - _radius_edges[1]).abs().min()

tensor(1)

In [330]:
x

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        ...,
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<MmBackward0>)

In [314]:
(x[_radius_edges[0]] - x[_radius_edges[1]]).abs().min()

tensor(0., grad_fn=<MinBackward1>)

In [315]:
eps = 0
loss = torch.sqrt(eps + torch.sum((x[_radius_edges[0]] - x[_radius_edges[1]]) ** 2, dim=-1)).mean()# .mean()# + va.sum()
loss

tensor(0.0068, grad_fn=<MeanBackward0>)

In [316]:
loss.backward()
optimizer.step()
list(toy_module.parameters())

[Parameter containing:
 tensor([[-0.3646, -0.2950, -0.5071],
         [-0.3415, -0.2405,  0.0264],
         [-0.5583,  0.0593,  0.2373]], requires_grad=True),
 Parameter containing:
 tensor([[-0.5149,  0.3048, -0.5356],
         [-0.1742,  0.0945,  0.3880],
         [ 0.2685, -0.2282,  0.4230]], requires_grad=True),
 Parameter containing:
 tensor([[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]], requires_grad=True)]