In [3]:
import yaml
from pathlib import Path
from utils import get_data_loaders, load_checkpoint, log_epoch, Criterion, add_cuts_to_config
import torch
from models import Model
from tqdm import tqdm
import pandas as pd

In [4]:
# GPU index to run on; any negative value selects the CPU instead.
cuda_id = 3
log_name = '05_16_2022_17_06_28-GNN_half_dR_1'  # log id of the saved model to load

# The log id is '-'-separated: the second field names the config/setting and an
# optional third field is a cut id (the first field looks like a timestamp).
# Split once and reuse the parts.
log_name_parts = log_name.split('-')
setting = log_name_parts[1]

# Read the YAML config inside a context manager so the file handle is closed
# deterministically instead of being left open for the life of the kernel.
with Path(f'./configs/{setting}.yml').open('r') as config_file:
    config = yaml.safe_load(config_file)
device = torch.device(f'cuda:{cuda_id}' if cuda_id >= 0 else 'cpu')
log_path = Path(config['data']['log_dir']) / log_name

# Runs logged with a cut id get the corresponding cuts merged into the config.
if len(log_name_parts) > 2:
    cut_id = log_name_parts[2]
    config = add_cuts_to_config(config, cut_id)

In [5]:
# Build the per-split loaders (indexed by split name later on) together with the
# input dimensions needed to instantiate the model and the underlying dataset.
data_cfg = config['data']
batch_size = config['optimizer']['batch_size']
data_loaders, x_dim, edge_attr_dim, dataset = get_data_loaders(setting, data_cfg, batch_size)

[Splits]
    train: 332406. # pos: 55401, # neg: 277005. Pos:Neg: 0.200
    valid: 71226. # pos: 11871, # neg: 59355. Pos:Neg: 0.200
    test: 674713. # pos: 11873, # neg: 662840. Pos:Neg: 0.018


In [6]:
# Load the dataframe stored at the dataset's save path.
# NOTE: read_pickle executes pickle under the hood -- only use it on trusted files.
df_path = dataset.get_df_save_path()
df = pd.read_pickle(df_path)
# Compare the number of dataframe rows with the number of dataset entries.
(len(df), len(dataset))

(578745, 1078345)

In [7]:
# Instantiate the network with the dimensions reported by the data loaders and
# move it to the selected device before restoring the saved state.
virtual_node = config['data']['virtual_node']
model = Model(x_dim, edge_attr_dim, virtual_node, config['model'])
model = model.to(device)

# load_checkpoint takes an optimizer as well, so one is built here even though
# this notebook only evaluates (presumably its state is restored too -- verify).
learning_rate = config['optimizer']['lr']
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
load_checkpoint(model, optimizer, log_path, device)

criterion = Criterion(config['optimizer'])

[INFO] Loading checkpoint from 05_16_2022_17_06_28-GNN_half_dR_1


In [8]:
@torch.no_grad()
def eval_one_batch(data, model, criterion):
    """Run a gradient-free forward pass on one batch and compute its losses.

    Args:
        data: Batch object exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch`` and ``y``; it is also passed whole to the model.
        model: Callable network; it is switched to eval mode on every call.
        criterion: Callable taking ``(probabilities, labels)`` and returning a
            ``(loss, loss_dict)`` pair.

    Returns:
        Tuple ``(loss_dict, clf_logits)`` where ``clf_logits`` holds the raw
        (pre-sigmoid) model outputs, detached and moved to the CPU.
    """
    model.eval()

    clf_logits = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, batch=data.batch, data=data)
    # The criterion is fed probabilities, not logits; only its per-term
    # dictionary is needed here, so the scalar loss is discarded.
    _, loss_dict = criterion(clf_logits.sigmoid(), data.y)
    # ``.detach()`` is the supported replacement for the legacy ``.data`` attribute.
    return loss_dict, clf_logits.detach().cpu()


def run_one_epoch(data_loader, epoch, phase, device, model, criterion):
    """Evaluate ``model`` on every batch of ``data_loader`` and gather the outputs.

    Args:
        data_loader: Sized iterable of batches. Each batch must support
            ``.to(device)`` and expose ``y``, ``sample_idx`` and ``endcap``.
        epoch: Epoch number forwarded to ``log_epoch`` for the progress text.
        phase: Split name used in the progress text (``'test'`` is padded with
            a trailing space so the descriptions line up).
        device: Device each batch is moved to before the forward pass.
        model: Network forwarded to ``eval_one_batch``.
        criterion: Loss callable forwarded to ``eval_one_batch``.

    Returns:
        Tuple ``(avg_loss, auroc, recall, all_clf_logits, all_sample_idx,
        all_endcap)``. The first three come from ``log_epoch`` evaluated on the
        whole split; the last three are the per-batch CPU tensors concatenated
        in loader order.

    Raises:
        ValueError: If ``data_loader`` has no batches (previously this surfaced
            as an ``UnboundLocalError`` at the ``return`` statement).
    """
    loader_len = len(data_loader)
    if loader_len == 0:
        raise ValueError('run_one_epoch received an empty data_loader; nothing to evaluate.')
    phase = 'test ' if phase == 'test' else phase  # align tqdm desc bar

    loss_sums = {}
    logits_chunks, label_chunks, sample_idx_chunks, endcap_chunks = [], [], [], []
    pbar = tqdm(data_loader, total=loader_len)
    for idx, data in enumerate(pbar):
        loss_dict, clf_logits = eval_one_batch(data.to(device), model, criterion)
        # Move the labels to the CPU once and reuse them for logging and collection.
        batch_labels = data.y.detach().cpu()

        desc = log_epoch(epoch, phase, loss_dict, clf_logits, batch_labels, batch=True)
        for k, v in loss_dict.items():
            loss_sums[k] = loss_sums.get(k, 0) + v
        logits_chunks.append(clf_logits)
        label_chunks.append(batch_labels)
        sample_idx_chunks.append(data.sample_idx.detach().cpu())
        endcap_chunks.append(data.endcap.detach().cpu())

        if idx == loader_len - 1:
            # The whole-split summary is computed on the last iteration (rather
            # than after the loop) so the final description is set while the
            # progress bar is still being iterated.
            all_clf_logits, all_clf_labels = torch.cat(logits_chunks), torch.cat(label_chunks)
            all_sample_idx, all_endcap = torch.cat(sample_idx_chunks), torch.cat(endcap_chunks)
            # Losses are averaged over the number of batches, not over samples.
            avg_loss_dict = {k: v / loader_len for k, v in loss_sums.items()}
            desc, auroc, recall, avg_loss = log_epoch(epoch, phase, avg_loss_dict, all_clf_logits, all_clf_labels, False, None)
        pbar.set_description(desc)

    return avg_loss, auroc, recall, all_clf_logits, all_sample_idx, all_endcap

In [9]:
# Score every split with the restored model. The epoch number (999) is only a
# label that shows up in the progress text.
split_probs, split_sample_idx, split_endcap = [], [], []
for phase in ('train', 'valid', 'test'):
    epoch_outputs = run_one_epoch(data_loaders[phase], 999, phase, device, model, criterion)
    avg_loss, auroc, recall, clf_logits, sample_idx, endcap = epoch_outputs
    split_probs.append(clf_logits.sigmoid())  # logits -> probabilities
    split_sample_idx.append(sample_idx)
    split_endcap.append(endcap)

# Stack the three splits (train, valid, test order) into single tensors.
clf_probs = torch.cat(split_probs)
all_sample_idx = torch.cat(split_sample_idx)
all_endcap = torch.cat(split_endcap)

[Epoch: 999]: train finished, focal: 0.006, total: 0.006, auroc: 0.659, recall@maxfpr: 0.003: 100%|██████████| 1299/1299 [01:18<00:00, 16.64it/s]
[Epoch: 999]: valid finished, focal: 0.006, total: 0.006, auroc: 0.664, recall@maxfpr: 0.002: 100%|██████████| 279/279 [00:15<00:00, 18.12it/s]
[Epoch: 999]: test  finished, focal: 0.005, total: 0.005, auroc: 0.664, recall@maxfpr: 0.004: 100%|██████████| 2636/2636 [02:12<00:00, 19.89it/s]


In [11]:
# One row per evaluated graph: the sample it belongs to, the predicted
# probability and the endcap value. Tensors are converted explicitly so the
# columns are plain numeric arrays regardless of how pandas treats tensors.
scores = pd.DataFrame({
    'sample_idx': all_sample_idx.numpy(),
    'probs': clf_probs.reshape(-1).numpy(),
    'endcap': all_endcap.numpy(),
})

# Build the single-entry {endcap: prob} dicts by zipping two arrays instead of
# a row-wise ``apply(axis=1)``, which constructs a Series per row and is very
# slow at this size. Both arrays are cast to float64 so keys and values are
# floats, as in the saved output of the row-wise version.
endcap_values = scores['endcap'].to_numpy(dtype='float64')
prob_values = scores['probs'].to_numpy(dtype='float64')
scores['score_dict'] = [{endcap: prob} for endcap, prob in zip(endcap_values, prob_values)]

# A stable sort keeps rows sharing a sample_idx in their original order, so the
# order in which their dicts are merged later is deterministic.
scores = scores.sort_values('sample_idx', kind='stable').reset_index(drop=True)

In [29]:
def agg_endcap(x):
    """Merge an iterable of dictionaries into a single new dictionary.

    Args:
        x: Iterable of dicts (here: the single-entry ``{endcap: prob}`` dicts
            belonging to one sample).

    Returns:
        A new dict containing every key/value pair; when a key occurs more
        than once the value from the later dict wins. The inputs are not
        modified. An empty iterable yields ``{}``.
    """
    res = {}
    for each in x:
        # In-place update has the same right-wins semantics as ``res | each``
        # but avoids allocating a new dict per step and works before Python 3.9.
        res.update(each)
    return res

In [33]:
# Collapse the rows of each sample into one merged {endcap: prob} dictionary.
score_dicts_by_sample = scores.groupby('sample_idx')['score_dict']
agg_scores = score_dicts_by_sample.agg(agg_endcap)
# Sanity check: exactly one aggregated entry per row of the dataframe.
assert len(df) == len(agg_scores)

In [42]:
# Persist the aggregated scores in the same directory as the dataset dataframe.
scores_path = dataset.get_df_save_path().parent / f'{setting}-scores.pkl'
agg_scores.to_pickle(scores_path)

In [44]:
# Display the aggregated per-sample score dictionaries, indexed by sample_idx
# (pandas abbreviates the middle of long Series in the rendered output).
agg_scores

sample_idx
0                                {-1.0: 0.5348360538482666}
1                                {-1.0: 0.4922066926956177}
2                                 {1.0: 0.5420222282409668}
3                                {-1.0: 0.5099743604660034}
4                                {-1.0: 0.5291178822517395}
                                ...                        
578740    {-1.0: 0.41937151551246643, 1.0: 0.48859730362...
578741    {1.0: 0.47036778926849365, -1.0: 0.50370985269...
578742    {1.0: 0.5109522342681885, -1.0: 0.486385315656...
578743    {-1.0: 0.5223032236099243, 1.0: 0.527758598327...
578744    {1.0: 0.5292389392852783, -1.0: 0.486385315656...
Name: score_dict, Length: 578745, dtype: object