In [28]:
import yaml
from pathlib import Path
from utils import Tau3MuDataset, Root2Df, get_data_loaders, load_checkpoint, log_epoch, Criterion
import matplotlib.pyplot as plt
import numpy as np
import torch
from models import Model
from tqdm import tqdm
from torch_geometric.loader import DataLoader

In [2]:
# Run selection: `setting` picks ./configs/<setting>.yml, `log_name` picks the
# run directory (under config['data']['log_dir']) that the checkpoint is read from.
setting = 'GNN-full-dR-1'
cuda_id = 7  # GPU index; any negative value selects the CPU instead
log_name = '03_06_2022-16_34_42-GNN-full-dR-1'

# read_text() opens, reads and closes the file in one call, so no file handle
# is left open for the lifetime of the kernel.
config = yaml.safe_load(Path(f'./configs/{setting}.yml').read_text())
device = torch.device(f'cuda:{cuda_id}' if cuda_id >= 0 else 'cpu')
log_path = Path(config['data']['log_dir']) / log_name

In [65]:
def filter_samples(x):
    """Decide whether one dataframe row (one event) passes the selection.

    A row is accepted when exactly three entries of each per-muon quantity
    pass their threshold (momentum, transverse momentum and |eta|), or when
    the row's label ``y`` equals 0.

    Args:
        x: Mapping-like row providing ``gen_mu_e``, ``gen_mu_pt``,
            ``gen_mu_eta`` (array-like, one entry per muon) and ``y``.

    Returns:
        A truthy value if the row is kept, a falsy one otherwise.
    """
    # Selection thresholds. Units are presumably GeV -- TODO confirm.
    mass = 0.1057        # presumably the muon mass
    min_p = 2.5
    min_pt = 0.5
    max_abs_eta = 2.8
    n_required = 3       # every one of the three entries has to pass

    # Momentum magnitude from energy and mass; the 1e-5 offset presumably
    # guards against a slightly negative sqrt argument -- TODO confirm.
    p = np.sqrt(x['gen_mu_e']**2 - mass**2 + 1e-5)
    pt = np.array(x['gen_mu_pt'])
    abs_eta = np.abs(x['gen_mu_eta'])

    # NOTE: a looser "at least one entry with pt > 2.0 and |eta| < 2.4"
    # requirement used to be computed here but never influenced the result,
    # so the dead computation has been dropped.
    cut_1 = (
        ((p > min_p).sum() == n_required)
        and ((pt > min_pt).sum() == n_required)
        and ((abs_eta < max_abs_eta).sum() == n_required)
    )
    return (cut_1) or x['y'] == 0

In [3]:
# Build the loaders for the chosen setting; the feature dimensions and the
# dataset object returned with them are reused by later cells.
data_loaders, x_dim, edge_attr_dim, dataset = get_data_loaders(
    setting,
    config['data'],
    config['optimizer']['batch_size'],
)

[Splits]
    train: 332406. # pos: 55401, # neg: 277005. Pos:Neg: 0.200
    valid: 71226. # pos: 11871, # neg: 59355. Pos:Neg: 0.200
    test: 175113. # pos: 11873, # neg: 163240. Pos:Neg: 0.073


In [66]:
# Row-wise selection mask: one filter_samples() verdict per dataframe row.
cut_1 = dataset.df.apply(filter_samples, axis=1)

In [73]:
# Indices of rows accepted by the selection mask that also belong to the test
# split. The list order comes from set iteration, not from the dataframe.
test_cut_1 = list(
    set(dataset.df.loc[cut_1 == True].index.to_list()).intersection(
        dataset.idx_split['test']
    )
)

In [74]:
# Sub-dataset built from the selected test indices (presumably restricted to
# exactly those rows -- confirm against Tau3MuDataset.copy), iterated without
# shuffling at the configured batch size.
test_set = dataset.copy(test_cut_1)
test_loader = DataLoader(
    test_set,
    batch_size=config['optimizer']['batch_size'],
    shuffle=False,
)

In [75]:
# Sum of the labels in the filtered test set (equals the number of positive
# samples if labels are 0/1 -- confirm).
torch.sum(test_set.data.y)

tensor(1551.)

In [4]:
# Rebuild the network with the dimensions reported by get_data_loaders and
# place it on the selected device.
model = Model(
    x_dim,
    edge_attr_dim,
    config['data']['virtual_node'],
    config['model'],
).to(device)

# The optimizer is created so it can be handed to load_checkpoint; none of
# the cells shown here performs an optimisation step.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['optimizer']['lr'],
)
load_checkpoint(model, optimizer, log_path, device)
criterion = Criterion(config['optimizer'])

[INFO] Loading checkpoint from 03_06_2022-16_34_42-GNN-full-dR-1


In [5]:
@torch.no_grad()
def eval_one_batch(data, model, criterion):
    """Run a single forward pass in evaluation mode and score it.

    Gradient tracking is disabled by the decorator and ``model`` is switched
    to eval mode on every call.

    Args:
        data: Batch object exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch`` and ``y``; it is also passed to the model as ``data``.
        model: Module called with the batch fields; returns the logits.
        criterion: Callable taking ``(probabilities, labels)`` and returning
            ``(loss, loss_dict)``.

    Returns:
        ``(loss_dict, logits)`` where ``logits`` are the raw model outputs
        moved to the CPU. The scalar loss itself is discarded.
    """
    model.eval()

    forward_inputs = dict(
        x=data.x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
        batch=data.batch,
        data=data,
    )
    logits = model(**forward_inputs)

    # The criterion is fed sigmoid probabilities, not the raw logits.
    probabilities = logits.sigmoid()
    _, batch_losses = criterion(probabilities, data.y)
    return batch_losses, logits.data.cpu()


def run_one_epoch(data_loader, epoch, phase, device, model, criterion):
    """Evaluate ``model`` on every batch of ``data_loader`` and report metrics.

    Each batch is moved to ``device`` and scored with ``eval_one_batch``. The
    per-batch loss entries are summed and, on the last batch, divided by the
    number of batches; all logits and labels are concatenated and handed to
    ``log_epoch`` to produce the epoch summary.

    Args:
        data_loader: Sized iterable of batches; each batch must provide
            ``to(device)`` and a ``y`` tensor.
        epoch: Epoch label forwarded to ``log_epoch`` for the progress text.
        phase: Phase name forwarded to ``log_epoch`` (``'test'`` is padded
            with a trailing space so the progress-bar text lines up).
        device: Device each batch is moved to before evaluation.
        model: Model passed through to ``eval_one_batch``.
        criterion: Loss callable passed through to ``eval_one_batch``.

    Returns:
        ``(avg_loss, auroc, recall)`` as computed by ``log_epoch`` on the
        concatenated predictions of the whole loader.

    Raises:
        ValueError: If ``data_loader`` has no batches.
    """
    loader_len = len(data_loader)
    if loader_len == 0:
        # Without this guard the loop body never runs and the final return
        # fails with an opaque UnboundLocalError.
        raise ValueError('run_one_epoch() received an empty data_loader; nothing to evaluate.')
    phase = 'test ' if phase == 'test' else phase  # align tqdm desc bar

    all_loss_dict, all_clf_logits, all_clf_labels = {}, [], []
    pbar = tqdm(data_loader, total=loader_len)
    for idx, data in enumerate(pbar):
        loss_dict, clf_logits = eval_one_batch(data.to(device), model, criterion)
        clf_labels = data.y.data.cpu()  # moved to the CPU once, reused below

        desc = log_epoch(epoch, phase, loss_dict, clf_logits, clf_labels, batch=True)
        for k, v in loss_dict.items():
            all_loss_dict[k] = all_loss_dict.get(k, 0) + v
        all_clf_logits.append(clf_logits)
        all_clf_labels.append(clf_labels)

        # The epoch summary is produced inside the final iteration so that the
        # progress bar's last description is the epoch line, not a batch line.
        if idx == loader_len - 1:
            all_clf_logits, all_clf_labels = torch.cat(all_clf_logits), torch.cat(all_clf_labels)
            for k, v in all_loss_dict.items():
                all_loss_dict[k] = v / loader_len
            desc, auroc, recall, avg_loss = log_epoch(epoch, phase, all_loss_dict, all_clf_logits, all_clf_labels, False, None)
        pbar.set_description(desc)

    return avg_loss, auroc, recall

In [76]:
# 999 is only the epoch label that ends up in the progress-bar text.
run_one_epoch(
    data_loader=test_loader,
    epoch=999,
    phase='test',
    device=device,
    model=model,
    criterion=criterion,
)

[Epoch: 999]: test  finished, focal: 0.000, total: 0.000, auroc: 0.996, recall@maxfpr: 0.901: 100%|██████████| 644/644 [00:12<00:00, 52.30it/s]


(0.00020636552601205061, 0.9955285979546042, 0.900709219858156)