Skip to content

Commit

Permalink
Add noise-classification loss, metrics, and training module (work-in-progress dump)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aryaman Jeendgar committed Mar 21, 2024
1 parent bdd7ae7 commit b4df3c8
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/gnn_tracking/metrics/losses/classification.py
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from torch.nn import CrossEntropyLoss

class CEL(nn.Module, HyperparametersMixin):
    """Thin wrapper around ``torch.nn.CrossEntropyLoss`` that records its
    per-class weights as hyperparameters.

    Args:
        weight: Sequence of per-class weights (one entry per class) that is
            forwarded to ``CrossEntropyLoss``.
    """

    def __init__(self, weight):
        super().__init__()
        # Cast explicitly to float: CrossEntropyLoss requires the weight
        # tensor's dtype to match the (float) input logits, and e.g. a list
        # of ints would otherwise produce an integer tensor that fails at
        # loss evaluation time.
        self._loss_fct = CrossEntropyLoss(
            weight=torch.tensor(weight, dtype=torch.float)
        )
        self.weight = weight
        self.save_hyperparameters()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy loss of ``input`` logits vs. ``target``."""
        return self._loss_fct(input, target)
38 changes: 38 additions & 0 deletions src/gnn_tracking/metrics/noise_classification.py
@@ -0,0 +1,38 @@
import torch.nn as nn
import torchmetrics as mm
from torch_geometric.data import Data

from gnn_tracking.utils.graph_masks import get_good_node_mask

def get_fp_pt(
    data: Data, model: nn.Module, pt_thld: float = 0.9, max_eta: float = 4.0
) -> float:
    """Binary AUROC of the model's first output channel for separating noise
    hits (``particle_id == 0``) from other hits.

    Only hits accepted by the pt/eta quality mask are considered.

    Args:
        data: Point cloud / graph data object
        model: Model whose output dict holds node-wise scores under ``"H"``
        pt_thld: Minimum pt for a hit to be considered
        max_eta: Maximum ``|eta|`` for a hit to be considered

    Returns:
        AUROC score as a python float.
    """
    hit_mask = get_good_node_mask(data, pt_thld=pt_thld, max_eta=max_eta)
    data = data.subgraph(hit_mask)
    aucroc = mm.AUROC(task="binary")
    preds = model(data)["H"][:, 0]
    return float(aucroc(preds, data.particle_id == 0))

def get_efficiency_purity_edges(
    data: Data, pt_thld: float = 0.9, max_eta: float = 4.0
) -> dict[str, float]:
    """Calculate efficiency and purity for edges based on ``data.true_edge_index``.
    Only edges where at least one of the two nodes is accepted by the pt threshold
    (and reconstructable etc.) are considered.
    """
    accepted = get_good_node_mask(data, pt_thld=pt_thld, max_eta=max_eta)
    src, dst = data.edge_index[0], data.edge_index[1]
    considered = accepted[src] | accepted[dst]
    true_src, true_dst = data.true_edge_index[0], data.true_edge_index[1]
    relevant_true = accepted[true_src] & accepted[true_dst]
    n_true_found = data.y[considered].sum()
    # True edges are stored undirected, hence the factor of two in the
    # efficiency denominator.
    return {
        "efficiency": (n_true_found / (2 * relevant_true.sum())).item(),
        "purity": (n_true_found / considered.sum()).item(),
    }
5 changes: 5 additions & 0 deletions src/gnn_tracking/models/graph_construction.py
Expand Up @@ -12,6 +12,7 @@
from torch import nn
from torch.jit import script as jit
from torch_cluster import knn_graph
from torch.nn.functional import softmax
from torch_geometric.data import Data

from gnn_tracking.models.mlp import MLP, HeterogeneousResFCNN, ResFCNN
Expand All @@ -31,6 +32,7 @@ def __init__(
out_dim: int,
depth: int,
alpha: float = 0.6,
classification: bool = False
):
"""Fully connected neural network for graph construction.
Contains additional normalization parameter for the latent space.
Expand All @@ -46,10 +48,13 @@ def __init__(
self._latent_normalization = torch.nn.Parameter(
torch.tensor([1.0]), requires_grad=True
)
self.classification = classification
self.save_hyperparameters()

def forward(self, data: Data) -> dict[str, T]:
    """Embed the node features into the normalized latent space.

    Returns:
        Dict with key ``"H"``: the node embedding, or class scores if
        ``classification`` is enabled.
    """
    out = super().forward(data.x) * self._latent_normalization
    if self.classification:
        # Explicit dim: relying on softmax's implicit dim selection is
        # deprecated. For the presumably 2-D (n_nodes, out_dim) output this
        # matches the previous implicit choice — TODO confirm output rank.
        out = softmax(out, dim=-1)
    return {"H": out}


Expand Down
84 changes: 84 additions & 0 deletions src/gnn_tracking/training/classification.py
@@ -0,0 +1,84 @@
import torch
# import torchmetrics
from typing import Any

from torch import Tensor
from torch import Tensor as T
from torch_geometric.data import Data

from torch.nn import CrossEntropyLoss, BCELoss
from gnn_tracking.metrics.noise_classification import get_fp_pt
from gnn_tracking.training.base import TrackingModule
from gnn_tracking.utils.dictionaries import add_key_suffix, to_floats
from gnn_tracking.utils.lightning import obj_from_or_to_hparams
from gnn_tracking.utils.oom import tolerate_some_oom_errors


class NodeClassifierModule(TrackingModule):
    """Pytorch lightning module with training and validation step for
    node-level noise classification (noise vs. signal hits).

    Args:
        loss_fct: Loss function applied to the model's node-wise output
            (e.g. ``BCELoss`` or ``CrossEntropyLoss``).
    """

    # noinspection PyUnusedLocal
    def __init__(
        self,
        *,
        loss_fct: torch.nn.Module,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.loss_fct: BCELoss | CrossEntropyLoss = obj_from_or_to_hparams(
            self, "loss_fct", loss_fct
        )

    # noinspection PyUnusedLocal
    def get_losses(
        self, out: dict[str, Any], data: Data
    ) -> tuple[T, dict[str, float]]:
        """Compute the classification loss.

        Targets are one-hot encoded: column 0 flags noise hits
        (``particle_id == 0``), column 1 flags signal hits.
        """
        # Build the (n_nodes, 2) one-hot target directly on the device/dtype
        # of the model output instead of hard-coding 'cuda' (which breaks
        # CPU runs) and round-tripping through a python list of pairs.
        targets = torch.stack(
            [data.particle_id == 0, data.particle_id != 0], dim=1
        ).to(dtype=out["H"].dtype, device=out["H"].device)
        loss = self.loss_fct(out["H"], targets)
        metrics: dict[str, float] = {"total": float(loss)}
        return loss, metrics

    @tolerate_some_oom_errors
    def training_step(self, batch: Data, batch_idx: int) -> Tensor | None:
        """Single training step: preprocess, forward, compute and log loss."""
        batch = self.data_preproc(batch)
        out = self(batch, _preprocessed=True)
        loss, loss_dct = self.get_losses(out, batch)
        self.log_dict(
            add_key_suffix(loss_dct, "_train"),
            prog_bar=True,
            on_step=True,
            batch_size=self.trainer.train_dataloader.batch_size,
        )
        return loss

    def validation_step(self, batch: Data, batch_idx: int):
        """Single validation step: log loss and the noise-classification AUROC."""
        batch = self.data_preproc(batch)
        out = self(batch, _preprocessed=True)
        loss, loss_dct = self.get_losses(out, batch)
        self.log_dict(
            add_key_suffix(loss_dct, "_val"),
            prog_bar=True,
            on_step=True,
            batch_size=self.trainer.val_dataloaders.batch_size,
        )
        metrics = {"fp_pt": get_fp_pt(batch, self.model)}
        self.log_dict(
            metrics,
            batch_size=self.trainer.val_dataloaders.batch_size,
        )

    def on_validation_epoch_end(self) -> None:
        pass

    def highlight_metric(self, metric: str) -> bool:
        # NOTE(review): this metric list looks inherited from the
        # metric-learning module — confirm it is intended for classification.
        return metric in [
            "n_edges_frac_segment50_95",
            "total",
            "attractive",
            "repulsive",
            "max_frac_segment50",
        ]

0 comments on commit b4df3c8

Please sign in to comment.