In [7]:
from typing import DefaultDict

import pandas as pd

from gnn_tracking.metrics.binary_classification import BinaryClassificationStats
from gnn_tracking_hpo.util.paths import add_scripts_path

import matplotlib.pyplot as plt
import copy
import numpy as np
from gnn_tracking.graph_construction.graph_builder import load_graphs
from torch_geometric.loader import DataLoader
import torch
from tqdm import tqdm
# Presumably puts the gnn_tracking_hpo scripts directory on sys.path so that
# `tune_ec` (imported in the next cell) can be found — must run before it.
add_scripts_path()
import collections


In [8]:
from tune_ec import ECTrainable

## Load model

In [9]:
# Restore trial "0450a5df" of the "ec_tc04" tuning run (epoch=-1 presumably
# selects the latest checkpoint — confirm), shrinking the datasets to a single
# training graph and 100 test graphs so the profiling run stays quick.
trainable = ECTrainable.reinstate(
    "ec_tc04",
    "0450a5df",
    epoch=-1,
    config_override={
        "n_graphs_train": 1,
        "n_graphs_test": 100,
    },
)

[36m[13:50:06 gnnt_hpo] DEBUG: Loading config from /home/kl5675/ray_results/ec_tc04/ECTrainable_0450a5df_1_val_batch_size=40,adam_amsgrad=False,adam_beta1=0.9682,adam_beta2=0.9986,adam_eps=0.0000,adam_weight_decay=_2023-03-09_12-16-31/params.json[0m
[32m[13:50:06 gnnt_hpo] INFO: I'm running on a node with job ID=46112165[0m
[32m[13:50:06 gnnt_hpo] INFO: The ID of my dispatcher is 0[0m
[36m[13:50:06 gnnt_hpo] DEBUG: Got config
┌──────────────────────────────────────┬───────────────────────────────────────────────────────────────┐
│ _val_batch_size                      │ 40                                                            │
│ adam_amsgrad                         │ False                                                         │
│ adam_beta1                           │ 0.9681533655563015                                            │
│ adam_beta2                           │ 0.9986372465566375                                            │
│ adam_eps                            

In [22]:


from gnn_tracking.utils.timing import Timer
from gnn_tracking.utils.nomenclature import denote_pt
from gnn_tracking.metrics.binary_classification import get_maximized_bcs

from gnn_tracking.metrics.binary_classification import roc_auc_score


@torch.no_grad()
def single_test_step(
    self, val=True, apply_truth_cuts=False, max_batches=None
) -> dict[str, float]:
    """Test the model on the validation or test set, printing stage timings.

    Instrumented notebook copy of the trainer's test step: each stage of the
    per-batch loop prints its duration so slow parts can be identified. It is
    meant to be bound to a trainer instance via
    ``single_test_step.__get__(trainer)``.

    Args:
        val: Use validation dataset rather than test dataset
        apply_truth_cuts: Apply truth cuts (e.g., truth level pt cut) during
            the evaluation
        max_batches: Only process this many batches (useful for testing).
            ``None`` (or ``0``) processes every batch of the loader.

    Returns:
        Dictionary of metrics: the nan-mean of every per-batch metric, its
        sample standard deviation (``ddof=1``) under ``"<key>_std"``, merged
        with the output of ``self._evaluate_cluster_metrics``.
    """
    self.model.eval()

    # We collect part of the data in CPU memory for clustering & evaluation.
    # NOTE(review): nothing in this instrumented loop fills it, so
    # ``_evaluate_cluster_metrics`` receives an empty mapping — confirm that
    # is acceptable for this profiling run.
    cluster_eval_input: DefaultDict[
        str, list[np.ndarray]
    ] = collections.defaultdict(list)

    batch_metrics: DefaultDict[str, list[float]] = collections.defaultdict(list)
    loader = self.val_loader if val else self.test_loader
    timer = Timer()
    for batch_idx, data in enumerate(loader):
        # ``batch_idx`` is 0-based, so ``>=`` stops after exactly
        # ``max_batches`` batches have been processed.
        if max_batches and batch_idx >= max_batches:
            break
        # Discard this lap so that "to" measures only the device transfer
        # (assumes calling the Timer returns the time since its last call).
        timer()
        data = data.to(self.device)
        print("to", timer())
        model_output = self.evaluate_model(
            data, mask_pids_reco=False, apply_truth_cuts=apply_truth_cuts
        )
        print("eval", timer())
        batch_loss, these_batch_losses = self.get_batch_losses(model_output)
        print("losses", timer())
        batch_metrics["total"].append(batch_loss.item())
        for key, value in these_batch_losses.items():
            batch_metrics[key].append(value.item())
            batch_metrics[f"{key}_weighted"].append(
                value.item() * self._loss_weight_setter[key]
            )
        for key, value in self.evaluate_ec_metrics(model_output).items():
            batch_metrics[key].append(value)
        print("ec metrics", timer())

    # Merge all metrics in one big dictionary (plain Python floats throughout)
    metrics: dict[str, float] = (
        {k: np.nanmean(v).item() for k, v in batch_metrics.items()}
        | {
            f"{k}_std": np.nanstd(v, ddof=1).item()
            for k, v in batch_metrics.items()
        }
        | self._evaluate_cluster_metrics(cluster_eval_input)
    )

    self.test_loss.append(pd.DataFrame(metrics, index=[self._epoch]))
    for hook in self._test_hooks:
        hook(self, metrics)
    return metrics



@torch.no_grad()
def evaluate_ec_metrics_with_pt_thld(
    self, model_output: dict[str, torch.Tensor], pt_min: float, ec_threshold: float
) -> dict[str, float]:
    """Compute edge-classification metrics above one pt cut, with timings.

    Edges are first restricted with ``self._edge_pt_mask`` for ``pt_min``.
    Then threshold-based stats (at ``ec_threshold``), threshold-maximized
    stats, the full ROC AUC and partial ROC AUCs (max FPR 0.001, 0.01, 0.1)
    are computed on the surviving edges. The duration of each stage is
    printed.

    Args:
        model_output: Output of the model; reads the ``"edge_index"``,
            ``"pt"``, ``"w"`` and ``"y"`` entries.
        pt_min: pt threshold passed to ``self._edge_pt_mask``; per the
            original notes, edges where both nodes have ``pt <= pt_min``
            are discarded.
        ec_threshold: EC threshold used for the fixed-threshold stats.

    Returns:
        Metric dictionary whose keys have been renamed with
        ``denote_pt(key, pt_min)``.
    """
    stopwatch = Timer()
    keep = self._edge_pt_mask(model_output["edge_index"], model_output["pt"], pt_min)
    print("edge_pt_mask", stopwatch())

    scores = model_output["w"][keep]
    labels = model_output["y"][keep].long()
    print("retrieve", stopwatch())

    results = BinaryClassificationStats(
        output=scores, y=labels, thld=ec_threshold
    ).get_all()
    print("bcs", stopwatch())

    results |= get_maximized_bcs(output=scores, y=labels)
    print("maximized", stopwatch())

    # Imported here (not at the top) so the "roc" timing covers the same work
    # as before, including the one-time import cost on the first call.
    from torchmetrics.classification import BinaryAUROC

    results["roc_auc"] = BinaryAUROC()(preds=scores, target=labels).item()
    for fpr_cap in (0.001, 0.01, 0.1):
        partial_auroc = BinaryAUROC(max_fpr=fpr_cap)
        results[f"roc_auc_{fpr_cap}FPR"] = partial_auroc(
            preds=scores, target=labels
        ).item()
    print("roc", stopwatch())

    return {denote_pt(name, pt_min): value for name, value in results.items()}

@torch.no_grad()
def evaluate_ec_metrics(
    self, model_output: dict[str, torch.Tensor], ec_threshold=None
) -> dict[str, float]:
    """Collect edge-classification metrics for every configured pt threshold.

    Args:
        model_output: Model output dictionary. If its ``"w"`` entry is
            ``None``, no metrics are computed.
        ec_threshold: EC threshold; falls back to ``self.ec_threshold`` when
            omitted.

    Returns:
        The merged dictionaries from ``self.evaluate_ec_metrics_with_pt_thld``
        for each value in ``self.ec_eval_pt_thlds`` (later thresholds win on
        key collisions), or ``{}`` when ``model_output["w"]`` is ``None``.
    """
    threshold = self.ec_threshold if ec_threshold is None else ec_threshold
    if model_output["w"] is None:
        return {}
    merged: dict[str, float] = {}
    for pt_thld in self.ec_eval_pt_thlds:
        merged |= self.evaluate_ec_metrics_with_pt_thld(
            model_output, pt_thld, ec_threshold=threshold
        )
    return merged


In [23]:

# Replace the trainer's methods with the instrumented versions defined above,
# on this trainer instance only: ``function.__get__(obj)`` returns the
# function bound to ``obj`` (i.e. a method with ``self=obj``).
trainable.trainer.single_test_step = single_test_step.__get__(
    trainable.trainer
)
trainable.trainer.evaluate_ec_metrics = evaluate_ec_metrics.__get__(
    trainable.trainer
)
trainable.trainer.evaluate_ec_metrics_with_pt_thld = (
    evaluate_ec_metrics_with_pt_thld.__get__(trainable.trainer)
)

In [24]:
# Run the trainer's test step with the instrumented methods bound in the
# previous cell; the "to"/"eval"/"losses"/... lines below are their timing
# prints, followed by the returned metrics dictionary.
trainable.trainer.test_step()



to 0.008508029859513044
eval 0.27220815513283014
losses 0.04457372194156051
edge_pt_mask 7.415888831019402e-05
retrieve 0.0003614271990954876
bcs 0.0009067277424037457
maximized 0.13637115713208914
roc 0.010998062789440155
edge_pt_mask 7.122429087758064e-05
retrieve 0.0003327326849102974
bcs 0.0005731731653213501
maximized 0.09709221124649048
roc 0.007886095903813839
edge_pt_mask 7.050298154354095e-05
retrieve 0.0002961037680506706
bcs 0.00037442101165652275
maximized 0.061729167122393847
roc 0.005216727964580059
edge_pt_mask 6.903894245624542e-05
retrieve 0.0002932990901172161
bcs 0.00033126072958111763
maximized 0.054693051148205996
roc 0.004458197858184576
edge_pt_mask 6.843777373433113e-05
retrieve 0.00028972234576940536
bcs 0.0003193686716258526
maximized 0.05346680525690317
roc 0.0041329991072416306
ec metrics 0.44308759504929185
to 0.009095221292227507
eval 0.45257940562441945
losses 0.05233839340507984
edge_pt_mask 6.410013884305954e-05
retrieve 0.00042120786383748055
bcs 0.000

{'total': 0.00047760493665312725,
 'edge': 0.00047760493665312725,
 'edge_weighted': 0.00047760493665312725,
 'acc': 0.9837633310778052,
 'TPR': 0.5470888409619671,
 'TNR': 0.999328839469452,
 'FPR': 0.0006711605305479777,
 'FNR': 0.452911159038033,
 'balanced_acc': 0.7732088402157095,
 'F1': 0.6987311722693793,
 'MCC': 0.7207973697414006,
 'max_ba': 0.8551049386261279,
 'max_ba_loc': 0.28140702843666077,
 'max_f1': 0.728826307626561,
 'max_f1_loc': 0.428810715675354,
 'max_mcc': 0.7358555924619411,
 'max_mcc_loc': 0.44723618030548096,
 'tpr_eq_tnr': 0.843746644894534,
 'tpr_eq_tnr_loc': 0.18592964112758636,
 'roc_auc': 0.9439328908920288,
 'roc_auc_0.001FPR': 0.7535172502199808,
 'roc_auc_0.01FPR': 0.8133438030878702,
 'roc_auc_0.1FPR': 0.8655538360277811,
 'acc_pt0.5': 0.990975553924148,
 'TPR_pt0.5': 0.6469986095441717,
 'TNR_pt0.5': 0.9996246531279915,
 'FPR_pt0.5': 0.0003753468720084944,
 'FNR_pt0.5': 0.3530013904558283,
 'balanced_acc_pt0.5': 0.8233116313360815,
 'F1_pt0.5': 0.77

In [86]:
class Test:
    def test(self, a):
        print(f"a {a}")

In [87]:
Test().test(a=5)

a 5


In [88]:
def mytest(self, a):
    """Replacement for ``Test.test``: print ``a`` with a ``"my test a "`` prefix."""
    message = f"my test a {a}"
    print(message)

In [89]:
t = Test()
# Bind ``mytest`` to this one instance, shadowing ``Test.test`` for ``t`` only.
# (``addself`` was never defined in this notebook, so the old version raised
# NameError on a fresh kernel; the descriptor protocol needs no helper.)
t.test = mytest.__get__(t)

In [91]:
t.test(3)

my test a 3
