In [1]:
import os

import numpy as np
import torch
from tqdm import tqdm
from torch_geometric.data import DataLoader

from cluster_gnn.models import gnn
from cluster_gnn.data import loader as loadtools

In [2]:
# configure the architecture
# NOTE(review): these presumably have to match the values the saved
# checkpoint was trained with -- confirm before changing them.
FINAL_BIAS = False      # forwarded to gnn.Net as final_bias
DIM_EMBED_EDGE = 64     # forwarded to gnn.Net as dim_embed_edge
DIM_EMBED_NODE = 32     # forwarded to gnn.Net as dim_embed_node
# Project root used to locate the saved model weights. Defaults to the
# original hard-coded location; set the CLUSTER_GNN_ROOT environment variable
# to run the notebook from a different checkout or machine.
ROOT_DIR = os.environ.get('CLUSTER_GNN_ROOT',
                          '/home/jlc1n20/projects/cluster_gnn/')

In [3]:
# load in the network
device = torch.device('cpu')
model = gnn.Net(dim_embed_edge=DIM_EMBED_EDGE,
                dim_embed_node=DIM_EMBED_NODE,
                final_bias=FINAL_BIAS).to(device)
# Build the checkpoint path with os.path.join so it is well-formed whether or
# not ROOT_DIR ends with a path separator (plain concatenation gave '//').
weights_path = os.path.join(ROOT_DIR, 'models', 'pos80lr14.pt')
# NOTE(review): torch.load unpickles the file, so only point this at
# checkpoints from a trusted source.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module summary.
model.eval()

Net(
  (conv1): Interaction(
    (mlp_edge): Sequential(
      (0): Linear(in_features=8, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (mlp_node): Sequential(
      (0): Linear(in_features=68, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
    )
  )
  (conv2): Interaction(
    (mlp_edge): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (mlp_node): Sequential(
      (0): Linear(in_features=96, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
    )
  )
  (conv3): Interaction(
    (mlp_edge): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (mlp_node): Sequential(
      

In [4]:
# Build a loader that yields, one graph per batch, the events at
# indices 90000-90009 of the event dataset.
dataset = loadtools.EventDataset()[90000:90010]
loader = DataLoader(
    dataset,
    batch_size=1,
)

In [7]:
# Sigmoid scores at or above this value are classified as positive.
THRESH = 0.90


def _confusion_counts(pred, label):
    """Return (true_pos, true_neg, false_pos, false_neg) for two bool arrays.

    Parameters
    ----------
    pred : np.ndarray of bool
        Predicted class per element (True = positive).
    label : np.ndarray of bool
        Ground-truth class per element, same shape as ``pred``.

    Returns
    -------
    tuple of int
        Counts computed directly from boolean masks, so they are well defined
        even when every prediction falls into a single class (the previous
        np.unique-based lookup raised IndexError in that case).
    """
    true_pos = int(np.count_nonzero(pred & label))
    true_neg = int(np.count_nonzero(~pred & ~label))
    false_pos = int(np.count_nonzero(pred & ~label))
    false_neg = int(np.count_nonzero(~pred & label))
    return true_pos, true_neg, false_pos, false_neg


# loop through the data in evaluation mode, printing classifications
with torch.no_grad():
    for data in tqdm(loader):
        # Flatten so labels line up element-for-element with the predictions.
        label = data.y.detach().numpy().astype(bool).reshape(-1)
        num_pos = int(label.sum())
        num_neg = label.size - num_pos
        output = torch.sigmoid(model(data))
        scores = output.detach().numpy().reshape(-1)
        # Same decision rule as before: anything not below THRESH is positive.
        pred = np.logical_not(scores < THRESH)
        true_pos, true_neg, false_pos, false_neg = _confusion_counts(pred, label)
        print(f'Graph has {data.num_edges} edges, of which {num_pos} are positive, and {num_neg} are negative.')
        print(f'Positive labels: {true_pos} correctly identified, {false_neg} incorrectly marked as negative.')
        print(f'Negative labels: {true_neg} correctly identified, {false_pos} incorrectly marked as positive.')

 10%|█         | 1/10 [00:02<00:20,  2.26s/it]

Graph has 64770 edges, of which 756 are positive, and 64014 are negative.
Positive labels: 723 correctly identified, 33 incorrectly marked as negative.
Negative labels: 63412 correctly identified, 602 incorrectly marked as positive.


 20%|██        | 2/10 [00:06<00:25,  3.23s/it]

Graph has 250500 edges, of which 4830 are positive, and 245670 are negative.
Positive labels: 4084 correctly identified, 746 incorrectly marked as negative.
Negative labels: 241761 correctly identified, 3909 incorrectly marked as positive.


 30%|███       | 3/10 [00:13<00:35,  5.13s/it]

Graph has 487902 edges, of which 1482 are positive, and 486420 are negative.
Positive labels: 1452 correctly identified, 30 incorrectly marked as negative.
Negative labels: 480726 correctly identified, 5694 incorrectly marked as positive.


 40%|████      | 4/10 [00:16<00:26,  4.35s/it]

Graph has 105300 edges, of which 930 are positive, and 104370 are negative.
Positive labels: 921 correctly identified, 9 incorrectly marked as negative.
Negative labels: 103086 correctly identified, 1284 incorrectly marked as positive.


 50%|█████     | 5/10 [00:21<00:23,  4.64s/it]

Graph has 293222 edges, of which 1332 are positive, and 291890 are negative.
Positive labels: 1249 correctly identified, 83 incorrectly marked as negative.
Negative labels: 290290 correctly identified, 1600 incorrectly marked as positive.


 60%|██████    | 6/10 [00:26<00:18,  4.59s/it]

Graph has 126380 edges, of which 1560 are positive, and 124820 are negative.
Positive labels: 1394 correctly identified, 166 incorrectly marked as negative.
Negative labels: 123595 correctly identified, 1225 incorrectly marked as positive.


 70%|███████   | 7/10 [00:29<00:12,  4.01s/it]

Graph has 115260 edges, of which 870 are positive, and 114390 are negative.
Positive labels: 658 correctly identified, 212 incorrectly marked as negative.
Negative labels: 113763 correctly identified, 627 incorrectly marked as positive.


 80%|████████  | 8/10 [00:32<00:07,  3.79s/it]

Graph has 153272 edges, of which 870 are positive, and 152402 are negative.
Positive labels: 813 correctly identified, 57 incorrectly marked as negative.
Negative labels: 150259 correctly identified, 2143 incorrectly marked as positive.


 90%|█████████ | 9/10 [00:34<00:03,  3.13s/it]

Graph has 63252 edges, of which 2070 are positive, and 61182 are negative.
Positive labels: 2070 correctly identified, 0 incorrectly marked as negative.
Negative labels: 58767 correctly identified, 2415 incorrectly marked as positive.


100%|██████████| 10/10 [00:39<00:00,  3.98s/it]

Graph has 304152 edges, of which 1056 are positive, and 303096 are negative.
Positive labels: 1050 correctly identified, 6 incorrectly marked as negative.
Negative labels: 301734 correctly identified, 1362 incorrectly marked as positive.





In [6]:
# Disabled scratch code, kept for reference. If re-enabled it would average,
# over every batch in `loader`, the ratio of non-positive to positive labels
# (len(y) - sum(y)) / sum(y), and print that mean.
# NOTE(review): a batch with no positive labels would divide by zero -- guard
# against that before reviving this cell.
# mean = 0.0
# with torch.no_grad():
#     for data in tqdm(loader):
#         mean += float(len(data.y) - data.y.sum()) / float(data.y.sum())
        
# mean = mean / float(len(loader))
# print(mean)