In [1]:

from gnn_tracking.models.mlp import MLP
from torch import Tensor
from torch import nn


class MetricLearningGraphConstruction(nn.Module):
    """Embed hits into a latent space and predict a per-hit beta value.

    Node features are passed through a shared MLP encoder. Two heads then
    produce the latent coordinates (``"H"``) and a sigmoid-squashed scalar per
    node (``"B"``).

    Args:
        node_indim: Number of input features per node in ``data.x``.
        outdim: Dimension of the latent space (output size of the ``"H"`` head).
        n_layers: ``L`` argument of the encoder MLP (its number of layers).
        layer_width: Output width of the encoder. It is also passed as the
            third positional argument (presumably the hidden width) of all
            three MLPs.
    """

    def __init__(self, *, node_indim: int, outdim: int = 12, n_layers: int, layer_width: int):
        super().__init__()
        self.encoder = MLP(node_indim, layer_width, layer_width, L=n_layers, include_last_activation=True)
        self.beta_nn = MLP(layer_width, 1, layer_width, L=1)
        self.latent = MLP(layer_width, outdim, layer_width, L=1, include_last_activation=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, data) -> dict[str, Tensor]:
        """Run the model on a graph-like object exposing node features as ``data.x``.

        Returns:
            Dict with ``"H"`` (latent coordinates, one row per node) and
            ``"B"`` (beta values in (0, 1), one entry per node).
        """
        h = self.encoder(data.x)
        latent = self.latent(h)
        # Squeeze only the trailing size-1 feature dimension: a bare
        # ``squeeze()`` would turn a single-node input into a 0-d tensor.
        beta = self.sigmoid(self.beta_nn(h)).squeeze(-1)
        return {"H": latent, "B": beta}

%load_ext autoreload
%autoreload 2


In [2]:
from gnn_tracking.utils.loading import TrackingDataset, get_loaders

# Root directory holding the point-cloud parts (part_1 ... part_9).
DATA_DIR = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v2"

# Parts 1-8 are used for training; part 9 is held out for validation.
train_dirs = [f"{DATA_DIR}/part_{part}" for part in range(1, 9)]
ds = TrackingDataset(train_dirs)
val_ds = TrackingDataset(f"{DATA_DIR}/part_9")

loaders = get_loaders(
    {"train": ds, "val": val_ds},
    batch_size=1,
    max_sample_size=10_000,
)

[32m[18:33:42] INFO: DataLoader will load 7743 graphs (out of 7743 available).[0m
[36m[18:33:42] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v2/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v2/part_8/data28999_s0.pt[0m
[32m[18:33:42] INFO: DataLoader will load 1000 graphs (out of 1000 available).[0m
[36m[18:33:42] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v2/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v2/part_9/data29999_s0.pt[0m
[36m[18:33:42] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x146e6f1fcca0>, 'pin_memory': True, 'shuffle': None}[0m
[36m[18:33:42] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True

In [3]:

from functools import partial
from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan, dbscan
from gnn_tracking.metrics.losses import PotentialLoss, BackgroundLoss

# Each entry maps a loss name to (loss module, weight). The potential loss
# takes one weight per component; the trainer logs these as
# potential_attractive / potential_repulsive.
potential_loss = PotentialLoss(q_min=0.01, radius_threshold=3)
background_loss = BackgroundLoss(sb=1)
potential_weights = {"attractive": 1, "repulsive": 1e-3}

loss_functions = {
    "potential": (potential_loss, potential_weights),
    "background": (background_loss, 0.05),
}



In [4]:

from typing import Any
from gnn_tracking.metrics.cluster_metrics import common_metrics
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner
import numpy as np
from gnn_tracking.postprocessing.clusterscanner import ClusterScanResult


def simple_scan(
    graphs: list[np.ndarray],
    truth: list[np.ndarray],
    sectors: list[np.ndarray],
    pts: list[np.ndarray],
    reconstructable: list[np.ndarray],
    start_params=None,
    node_mask=None,
    epoch=None,
    *,
    n_trials: int = 3,
    n_jobs: int = 3,
) -> ClusterScanResult:
    """Run a short DBSCAN hyperparameter scan and return its result.

    Registered as the ``"dbscan"`` cluster function in ``cfs`` below. The scan
    varies ``eps`` within ``(0.95, 1.25)`` with ``min_samples`` fixed to 1 and
    uses the adjusted Rand index as both guide and sole reported metric.

    Args:
        graphs, truth, sectors, pts, reconstructable: Per-graph arrays that are
            forwarded unchanged to ``DBSCANHyperParamScanner`` (``graphs`` is
            presumably the latent-space coordinates; confirm against the trainer).
        start_params: Initial DBSCAN parameters. Defaults to
            ``{"eps": 0.95, "min_samples": 1}``.
        node_mask: Forwarded to ``DBSCANHyperParamScanner``.
        epoch: Unused here (presumably passed by the trainer).
        n_trials: Number of scan trials (default 3).
        n_jobs: Number of parallel jobs for the scan (default 3).

    Returns:
        The ``ClusterScanResult`` from ``DBSCANHyperParamScanner.scan``.
    """
    if start_params is None:
        start_params = {
            "eps": 0.95,
            "min_samples": 1,
        }
    dbss = DBSCANHyperParamScanner(
        data=graphs,
        truth=truth,
        sectors=sectors,
        pts=pts,
        reconstructable=reconstructable,
        guide="adjusted_rand",
        metrics={"adjusted_rand": common_metrics["adjusted_rand"]},
        node_mask=node_mask,
        min_samples_range=(1, 1),
        eps_range=(0.95, 1.25),
    )
    return dbss.scan(
        n_jobs=n_jobs,
        n_trials=n_trials,
        start_params=start_params,
    )


cfs = {
    "dbscan": simple_scan
}

In [5]:
from gnn_tracking.training.tcn_trainer import TCNTrainer

# Encoder: 7 input features per hit, 6 layers of width 256.
model = MetricLearningGraphConstruction(node_indim=7, n_layers=6, layer_width=256)
trainer = TCNTrainer(
    model,
    loaders,
    loss_functions,
    lr=1e-3,
    cluster_functions=cfs,
)

[32m[18:33:43 TCNTrainer] INFO: Using device cuda[0m


In [6]:
# Checkpoint of a previous run (relative to the notebook's working directory).
CHECKPOINT_PATH = "230526_181642_model.pt"
trainer.load_checkpoint(CHECKPOINT_PATH)

In [None]:
trainer.train()

  storage = elem.storage()._new_shared(numel)
[36m[18:42:16 TCNTrainer] DEBUG: Epoch 3 (    0/10000): Total=   0.03912, potential_attractive=   0.00218, potential_repulsive=   0.00193, background=   0.03501 (weighted)[0m
[36m[18:42:18 TCNTrainer] DEBUG: Epoch 3 (   10/10000): Total=   0.04028, potential_attractive=   0.00194, potential_repulsive=   0.00117, background=   0.03717 (weighted)[0m
[36m[18:42:19 TCNTrainer] DEBUG: Epoch 3 (   20/10000): Total=   0.04027, potential_attractive=   0.00215, potential_repulsive=   0.00266, background=   0.03546 (weighted)[0m
[36m[18:42:21 TCNTrainer] DEBUG: Epoch 3 (   30/10000): Total=   0.03891, potential_attractive=   0.00180, potential_repulsive=   0.00197, background=   0.03514 (weighted)[0m
[36m[18:42:22 TCNTrainer] DEBUG: Epoch 3 (   40/10000): Total=   0.03948, potential_attractive=   0.00186, potential_repulsive=   0.00221, background=   0.03541 (weighted)[0m
[36m[18:42:24 TCNTrainer] DEBUG: Epoch 3 (   50/10000): Total=   0.0

In [7]:
data = trainer.val_loader.dataset[0]

In [8]:
import torch

# Inference only. The outputs are .detach()ed further down, which suggests they
# carry autograd history; disabling grad avoids keeping that graph alive on the GPU.
with torch.no_grad():
    mo = trainer.evaluate_model(data)

In [17]:
labels = dbscan(mo["x"].detach().cpu().numpy())

In [18]:
from sklearn import metrics


In [23]:
# NOTE(review): the recorded ARI is exactly 0.0, i.e. DBSCAN (called with its
# default parameters above) recovered no particle structure. Possibly every hit
# landed in a single cluster; inspect `labels` and the eps used by `dbscan`.
metrics.adjusted_rand_score(data.particle_id.detach().cpu().numpy(), labels)

0.0

In [9]:
from torch_cluster import radius_graph

In [10]:
# torch_cluster.radius_graph keeps at most `max_num_neighbors` neighbours per
# node (default 32) and silently drops the rest. That biases any
# graph-construction efficiency computed from this graph, so make the cap
# explicit and report how many nodes reached it.
MAX_NUM_NEIGHBORS = 32
edge_index = radius_graph(mo["x"], r=1, max_num_neighbors=MAX_NUM_NEIGHBORS)
# Assumes edge_index[1] holds the centre node (torch_cluster's default
# flow="source_to_target"); verify if the flow is changed.
node_degree = edge_index[1].bincount(minlength=mo["x"].shape[0])
{"n_nodes": mo["x"].shape[0], "n_at_cap": int((node_degree >= MAX_NUM_NEIGHBORS).sum())}

In [12]:
edge_index.shape

torch.Size([2, 1936826])

In [11]:
from torch_geometric.data import Data

# NOTE(review): this rebinds `data`, which previously held the raw validation
# graph passed to `trainer.evaluate_model` and used in the ARI cell. Re-running
# those cells after this one operates on this new object instead.
data = Data(x=mo["x"], edge_index=edge_index)

In [12]:
# Edge truth: an edge is "true" when both endpoints carry the same particle_id.
# NOTE(review): if noise hits share a common particle_id (e.g. 0), noise-noise
# pairs are counted as true edges here; confirm whether noise is present.
data.y = mo["particle_id"][edge_index[0, :]] == mo["particle_id"][edge_index[1, :]]
# Bug fix: `pt` was previously set to the particle IDs, which made the
# `data.pt > 0.9` edge mask later in the notebook meaningless.
# Assumes `evaluate_model` exposes the per-hit pT under "pt"; a KeyError here
# means it does not.
data.pt = mo["pt"]
data.particle_id = mo["particle_id"]

In [13]:
mo["x"].device

device(type='cuda', index=0)

In [14]:
data.pt.device, data.particle_id.device

(device(type='cuda', index=0), device(type='cuda', index=0))

In [16]:
%aimport gnn_tracking.analysis.graphs
from gnn_tracking.analysis.graphs import get_all_graph_construction_stats

# NOTE(review): in the recorded run this call raised
# "ValueError: max() arg is an empty sequence"; the cause is not visible from
# this notebook. `data` is already on cuda, so `.to("cuda")` is a no-op here.
r = get_all_graph_construction_stats(data.to("cuda"))

ValueError: max() arg is an empty sequence

In [None]:
r

In [17]:
from gnn_tracking.utils.graph_masks import get_edge_mask_from_node_mask
from gnn_tracking.metrics.binary_classification import BinaryClassificationStats
import torch

# pT threshold used to select the "high-pT" subset of edges.
PT_THLD = 0.9


def _all_true_stats(edge_truth):
    """Classification stats for a trivial classifier that calls every edge true.

    Every prediction is positive, so accuracy equals the fraction of true
    edges in ``edge_truth``.
    """
    truth_long = edge_truth.long()
    return BinaryClassificationStats(
        output=torch.ones_like(truth_long), y=truth_long, thld=0.0
    )


bcs = _all_true_stats(data.y)
pt_edge_mask = get_edge_mask_from_node_mask(data.pt > PT_THLD, data.edge_index)
bcs_thld = _all_true_stats(data.y[pt_edge_mask])

In [18]:
bcs.get_all()

{'acc': 0.02932068620213888,
 'TPR': 1.0,
 'TNR': 0.0,
 'FPR': 1.0,
 'FNR': 0.0,
 'balanced_acc': 0.5,
 'F1': 0.05697094519750254,
 'MCC': 0,
 'n_true': 57038,
 'n_false': 1888278,
 'n_predicted_true': 1945316,
 'n_predicted_false': 0}

In [19]:
# Fraction of true edges among all constructed edges. It is read from the stats
# rather than hard-coded so it stays correct when the graph is rebuilt.
bcs_stats = bcs.get_all()
bcs_stats["n_true"] / (bcs_stats["n_true"] + bcs_stats["n_false"])

0.02932068620213888