-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changes for noise-classification, dump
- Loading branch information
Aryaman Jeendgar
committed
Mar 21, 2024
1 parent
bdd7ae7
commit b4df3c8
Showing
4 changed files
with
148 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
import torch.nn as nn | ||
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin | ||
from torch.nn import CrossEntropyLoss | ||
|
||
class CEL(nn.Module, HyperparametersMixin):
    """Thin wrapper around ``torch.nn.CrossEntropyLoss`` with a fixed
    per-class weight that is recorded as a hyperparameter.

    Args:
        weight: Per-class weights forwarded to ``CrossEntropyLoss``.
            Accepts any sequence of numbers or an existing tensor.
    """

    def __init__(
        self,
        weight,
    ):
        super().__init__()
        # ``as_tensor`` avoids a copy (and a UserWarning) when ``weight`` is
        # already a tensor; force float dtype because CrossEntropyLoss
        # requires floating-point weights (a plain int list would otherwise
        # become a long tensor and fail at loss evaluation).
        self._loss_fct = CrossEntropyLoss(
            weight=torch.as_tensor(weight, dtype=torch.float)
        )
        self.weight = weight
        self.save_hyperparameters()

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return the cross-entropy loss of ``input`` logits vs. ``target``."""
        return self._loss_fct(input, target)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import torch.nn as nn | ||
import torchmetrics as mm | ||
from torch_geometric.data import Data | ||
|
||
from gnn_tracking.utils.graph_masks import get_good_node_mask | ||
|
||
def get_fp_pt(
    data: Data, model: nn.Module, pt_thld: float = 0.9, max_eta: float = 4.0
) -> torch.Tensor:
    """Return the binary AUROC for identifying noise hits.

    Hits are first restricted to the ones passing the pt/eta quality mask,
    then the model's node score ``H[:, 0]`` is evaluated against the
    "is noise" label (``particle_id == 0``).

    Args:
        data: Graph for one event.
        model: Model returning a dict with node output tensor under ``"H"``.
        pt_thld: Minimum pt for a hit to be considered.
        max_eta: Maximum |eta| for a hit to be considered.

    Returns:
        Scalar AUROC tensor.
    """
    hit_mask = get_good_node_mask(data, pt_thld=pt_thld, max_eta=max_eta)
    data = data.subgraph(hit_mask)
    aucroc = mm.AUROC(task="binary")
    preds = model(data)["H"][:, 0]
    # Positive class: hits with no associated particle (noise).
    return aucroc(preds, data.particle_id == 0)
|
||
def get_efficiency_purity_edges(
    data: Data, pt_thld: float = 0.9, max_eta: float = 4.0
) -> dict[str, float]:
    """Calculate efficiency and purity for edges based on ``data.true_edge_index``.

    An edge is considered as soon as at least one of its two endpoints is
    accepted by the pt threshold (and reconstructable etc.); a true edge is
    counted only when both of its endpoints are accepted.
    """
    node_ok = get_good_node_mask(data, pt_thld=pt_thld, max_eta=max_eta)
    src, dst = data.edge_index[0], data.edge_index[1]
    kept_edges = node_ok[src] | node_ok[dst]
    kept_true = node_ok[data.true_edge_index[0]] & node_ok[data.true_edge_index[1]]
    n_true_found = data.y[kept_edges].sum()
    # True edges are undirected, hence the factor of two in the denominator.
    efficiency = n_true_found / (2 * kept_true.sum())
    purity = n_true_found / kept_edges.sum()
    return {
        "efficiency": efficiency.item(),
        "purity": purity.item(),
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import torch | ||
# import torchmetrics | ||
from typing import Any | ||
|
||
from torch import Tensor | ||
from torch import Tensor as T | ||
from torch_geometric.data import Data | ||
|
||
from torch.nn import CrossEntropyLoss, BCELoss | ||
from gnn_tracking.metrics.noise_classification import get_fp_pt | ||
from gnn_tracking.training.base import TrackingModule | ||
from gnn_tracking.utils.dictionaries import add_key_suffix, to_floats | ||
from gnn_tracking.utils.lightning import obj_from_or_to_hparams | ||
from gnn_tracking.utils.oom import tolerate_some_oom_errors | ||
|
||
|
||
class NodeClassifierModule(TrackingModule):
    """Pytorch lightning module with training and validation step for
    node-level classification of hits into noise vs. non-noise.
    """

    # noinspection PyUnusedLocal
    def __init__(
        self,
        *,
        loss_fct: torch.nn.Module,
        **kwargs,
    ):
        """
        Args:
            loss_fct: Loss applied to the node logits ``out["H"]``
                (e.g. ``BCELoss`` or ``CrossEntropyLoss``).
            **kwargs: Forwarded to ``TrackingModule``.
        """
        super().__init__(**kwargs)
        self.loss_fct: BCELoss | CrossEntropyLoss = obj_from_or_to_hparams(
            self, "loss_fct", loss_fct
        )

    # noinspection PyUnusedLocal
    def get_losses(self, out: dict[str, Any], data: Data) -> tuple[T, dict[str, float]]:
        """Return the loss and a dict of float metrics for logging.

        Targets are two-column one-hot rows ``[is_noise, is_not_noise]``.
        They are built with ``torch.stack`` so they stay on the same device
        as ``data`` — the previous python-level ``zip`` + hard-coded
        ``.to('cuda')`` was both very slow and broke CPU-only runs.
        """
        targets = torch.stack(
            (data.particle_id == 0, data.particle_id != 0), dim=1
        ).float()
        loss = self.loss_fct(out["H"], targets)
        metrics = {"total": float(loss)}
        return loss, metrics

    @tolerate_some_oom_errors
    def training_step(self, batch: Data, batch_idx: int) -> Tensor | None:
        """One optimization step; logs training losses per step."""
        batch = self.data_preproc(batch)
        out = self(batch, _preprocessed=True)
        loss, loss_dct = self.get_losses(out, batch)
        self.log_dict(
            add_key_suffix(loss_dct, "_train"),
            prog_bar=True,
            on_step=True,
            batch_size=self.trainer.train_dataloader.batch_size,
        )
        return loss

    def validation_step(self, batch: Data, batch_idx: int):
        """Log validation losses plus the noise-classification AUROC."""
        batch = self.data_preproc(batch)
        out = self(batch, _preprocessed=True)
        loss, loss_dct = self.get_losses(out, batch)
        self.log_dict(
            add_key_suffix(loss_dct, "_val"),
            prog_bar=True,
            on_step=True,
            batch_size=self.trainer.val_dataloaders.batch_size,
        )
        metrics = {"fp_pt": get_fp_pt(batch, self.model)}
        self.log_dict(
            metrics,
            batch_size=self.trainer.val_dataloaders.batch_size,
        )

    def on_validation_epoch_end(self) -> None:
        # Nothing to aggregate; everything is logged per step.
        pass

    def highlight_metric(self, metric: str) -> bool:
        """Return whether ``metric`` should be highlighted in the output."""
        return metric in [
            "n_edges_frac_segment50_95",
            "total",
            "attractive",
            "repulsive",
            "max_frac_segment50",
        ]