In [1]:
from gnn_tracking.preprocessing.point_cloud_builder import PointCloudBuilder
import torch_geometric
import torch
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)

In [3]:
# unpack data

import os
import tarfile

raw_data_dir = "../raw_data"
output_dir = "../data"

for filename in os.listdir(raw_data_dir):
    if filename.endswith(".tar.gz"):
        tar_gz_path = os.path.join(raw_data_dir, filename)
        tar = tarfile.open(tar_gz_path, "r:gz")
        tar.extractall(output_dir)
        tar.close()


In [12]:
# Load data

raw_data = []
for filename in os.listdir(output_dir):
    directoryname = os.path.join(output_dir, filename)
    for data in os.listdir(directoryname):
        dataname = os.path.join(directoryname, data)
        raw_data.append(torch.load(dataname))


In [19]:
print(len(raw_data))
print(raw_data[0].x.size())
print(raw_data[0].particle_id.size())
print(torch.unique(raw_data[0].particle_id).size())

900
torch.Size([7285, 7])
torch.Size([7285])
torch.Size([1112])


In [60]:
gravnet = torch_geometric.nn.conv.GravNetConv(in_channels=7, out_channels=3, space_dimensions=8, propagate_dimensions=8, k=8)

In [69]:
# loss_function implementation
import math

def pairwise_distance(x):
    """Calculate pairwise distances between vectors (positions) in a batch."""
    diff = x.unsqueeze(2) - x.t().unsqueeze(0)
    dist = torch.sqrt(torch.sum(diff**2, dim=-1) + 1e-8) # add a small number to avoid numerical instability
    return dist

def gather_max_elements(ID, weight, position):
    """Find the item with the largest weight for each unique ID."""
    unique_ids = torch.unique(ID)
    max_positions = []
    for uid in unique_ids:
        indices = (ID == uid)
        weights = weight[indices]
        positions = position[indices]
        max_idx = torch.argmax(weights)
        max_positions.append(positions[max_idx])
    return torch.stack(max_positions), unique_ids

def gather_elements_with_id(ID, uid, position):
    """Gather elements that share the same ID value."""
    indices = (ID == uid)
    positions = position[indices]
    return positions

def loss_function(ID, weight, position):
    max_positions, unique_ids = gather_max_elements(ID, weight, position)
    
    mean_distances = []
    for i, uid in enumerate(unique_ids):
        positions = gather_elements_with_id(ID, uid, position)
        diff = positions - max_positions[i]
        dist = torch.sum(diff**2, dim=-1) + 1e-8
        mean_dist = torch.mean(dist)
        mean_distances.append(mean_dist)
    A = torch.mean(torch.stack(mean_distances))
    
    pairwise_dist = pairwise_distance(max_positions)
    print(pairwise_dist)
    B = torch.mean((1 - pairwise_dist) ** 2)

    return 100*A + B

In [70]:
y = torch.sigmoid(gravnet(raw_data[2].x)/300)

loss_function(raw_data[2].particle_id, y[:, 0], y[:, 1:])

tensor([[ 5.0511,  2.7903],
        [ 5.2653,  2.7800],
        [ 4.9359,  3.1573],
        ...,
        [ 7.1659,  3.6903],
        [10.0643,  5.5167],
        [11.2736,  6.8508]], grad_fn=<SqrtBackward0>)


tensor(22.6154, grad_fn=<AddBackward0>)