## Experimentation with GravNet_Conv

In [71]:
import torch_geometric
import torch
import os
import tarfile
import math

### 1. Obtain data for training

In [72]:
raw_data_dir = "../raw_data"
output_dir = "../data"

In [73]:
for filename in os.listdir(raw_data_dir):
    if filename.endswith(".tar.gz"):
        tar_gz_path = os.path.join(raw_data_dir, filename)
        tar = tarfile.open(tar_gz_path, "r:gz")
        tar.extractall(output_dir)
        tar.close()

In [74]:
raw_data = []

for foldername in os.listdir(output_dir):
    folder = os.path.join(output_dir, foldername)
    for filename in os.listdir(folder):
        if filename.endswith(".pt"):
            file = os.path.join(folder, filename)
            raw_data.append(torch.load(file))

### 2. Observe point cloud properties

In [75]:
print("Length of dataset:", len(raw_data))
print(raw_data[0].x.size())
print(raw_data[0].particle_id.size())
print(torch.unique(raw_data[0].particle_id).size())

Length of dataset: 900
torch.Size([7285, 7])
torch.Size([7285])
torch.Size([1112])


### 3. Loss function initializatinon

`strength_loss` aims to maximize condensation strength of condensation points, minimize condensation strength of non-condensation points.

$\frac{1}{|K|}\sum_{i \in K}(1 - \beta_{i})^2 + \frac{\epsilon}{|N| - |K|}\sum_{i \in (N/K)}(\beta_{i})^2$

In [182]:
def strength_loss(I, A, epsilon=1):
    unique_ids = torch.unique(I)
    max_weights = torch.empty((unique_ids.shape[0],), dtype=A.dtype).to(A.device)
    other_weights = torch.empty((0,), dtype=A.dtype).to(A.device)

    for idx, uid in enumerate(unique_ids):
        mask = I == uid
        weights = A[mask]

        max_weight = torch.max(weights)
        max_weights[idx] = max_weight

        non_max_mask = weights != max_weight
        non_max_weights = torch.masked_select(weights, non_max_mask)

        other_weights = torch.cat([other_weights, non_max_weights])

    loss_a = ((1 - max_weights) ** 2).mean()
    loss_b = ((other_weights)**2).mean()
    return (torch.tensor([4 / (1 + epsilon)]) * (loss_a + epsilon*loss_b), loss_a.item(), loss_b.item())


`potential_loss` aims minimize pairwise distance between hits and their respective condensation points while maintaining distance of at least 1 with other condensation points.

$\frac{1}{|N||K|}\sum_{j=1}^N\sum_{k=1}^K (max(1 - ||x_{\alpha k} - x_j||, 0))^2 + \frac{\epsilon'}{|N|}\sum_{j=1}^N ||x_{\alpha j} - x_j||^2$

In [200]:

def gather_max_elements(ID, weight, position):
    """Find the item with the largest weight for each unique ID."""
    unique_ids = torch.unique(ID)
    max_positions = []
    max_weights = []
    for uid in unique_ids:
        indices = (ID == uid)
        weights = weight[indices]
        positions = position[indices]
        max_idx = torch.argmax(weights)
        max_positions.append(positions[max_idx])
        max_weights.append(weight[max_idx])
    return torch.stack(max_positions), torch.stack(max_weights), unique_ids

def pairwise_distance(a, b):
    """Calculate pairwise distances between vectors (positions) in two different batches."""
    diff = a.unsqueeze(1) - b.unsqueeze(0)
    dist = torch.sqrt(torch.sum(diff**2, dim=-1) + 1e-8) # add a small number to avoid numerical instability
    return dist

def find_condensation_positions(ID_1, Pos_1, ID_2):
    """Find the representative position in Pos_1 for each ID in ID_2."""
    mapping = {id.item(): pos for id, pos in zip(ID_1, Pos_1)}
    return torch.stack([mapping[id.item()] for id in ID_2])

def potential_loss(ID, weight, position):
    con_position, _, con_ids = gather_max_elements(ID, weight, position)

    pairwise_dist = pairwise_distance(con_position, position)
    dist_transformed = (torch.clamp(1 - pairwise_dist, min=0)) ** 2
    repulsive_loss =  torch.mean(dist_transformed)

    condensation_positions = find_condensation_positions(con_ids, con_position, ID)
    
    # Compute squared Euclidean distance
    diff = position - condensation_positions
    attractive_loss = torch.mean(torch.sum(diff**2, dim=-1))
    
    # Average the squared distances
    return 2 * attractive_loss * repulsive_loss


#### 4. Training Loop

In [185]:
gravnet = torch_geometric.nn.conv.GravNetConv(in_channels=7, out_channels=3, space_dimensions=3, propagate_dimensions=8, k=16)

In [186]:
optimizer = torch.optim.Adam(gravnet.parameters(), lr=0.001)

In [187]:
normalization_factor = torch.abs(gravnet(raw_data[0].x)).mean().item()
print(normalization_factor)

149.3076934814453


In [188]:
data_index = 0
while True:
    optimizer.zero_grad()
    y = torch.sigmoid(gravnet(raw_data[data_index].x) / normalization_factor)
    loss, a, b = strength_loss(raw_data[data_index].particle_id, y[:, 0])
    print(loss.item(), a, b)
    loss.backward()
    optimizer.step()
    data_index += 1
    if(data_index >= len(raw_data)):
        data_index = 0

1.1970282793045044 0.2602432072162628 0.3382709324359894
1.1771377325057983 0.25587087869644165 0.3326979875564575
1.1712192296981812 0.26385027170181274 0.32175934314727783
1.1701385974884033 0.2818466126918793 0.3032227158546448
1.1727714538574219 0.263521283864975 0.32286444306373596
1.19313645362854 0.2576221227645874 0.3389461040496826
1.1694176197052002 0.2880071699619293 0.29670166969299316
1.1674909591674805 0.2570866644382477 0.32665884494781494
1.1454507112503052 0.2643033266067505 0.3084220290184021
1.1633687019348145 0.27595633268356323 0.3057280480861664
1.1562926769256592 0.25824806094169617 0.31989824771881104
1.1783478260040283 0.2503981292247772 0.33877578377723694
1.168221354484558 0.24673490226268768 0.33737578988075256
1.120903730392456 0.25511103868484497 0.30534079670906067
1.127129077911377 0.252450555562973 0.31111401319503784
1.1461725234985352 0.2517152428627014 0.32137104868888855
1.1558566093444824 0.25448405742645264 0.32344427704811096
1.152267575263977 0.

KeyboardInterrupt: 