# Conversion to Lightning

In [1]:
# System imports
import os
import sys
from pprint import pprint as pp
from time import time as tt
import inspect
import importlib

# External imports
import matplotlib.pyplot as plt
import matplotlib.colors
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch_geometric.data import DataLoader
from mpl_toolkits.mplot3d import Axes3D
from random import shuffle, sample
from torch_scatter import scatter_add

from itertools import chain

from torch.nn import Linear
import torch.nn.functional as F
from torch_scatter import scatter, segment_csr, scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_cluster import knn_graph, radius_graph
import trackml.dataset
import torch_geometric
from itertools import permutations
import itertools
import plotly.express as px

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

from sklearn.cluster import DBSCAN
from sklearn import metrics
from torch.utils.checkpoint import checkpoint

# Limit CPU usage on Jupyter
os.environ['OMP_NUM_THREADS'] = '4'

# Pick up local packages
sys.path.append('..')

# Local imports
from prepare_utils import *
from performance_utils import *
from toy_utils import *
from models import *
from trainers import *
%matplotlib inline


# Get rid of RuntimeWarnings, gross
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import wandb
import faiss
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
torch_seed = 0

In [10]:
importlib.reload(sys.modules['toy_utils'])
from toy_utils import *

In [43]:
importlib.reload(sys.modules['models'])
from models import *

## Data Preparation

### Load Scrubbed Filter-ready Events

In [2]:
# Dataset selection: transverse-momentum cut and number of train / test events.
pt_cut = 0.5
train_number = 1000
test_number = 100

# Pre-processed (filter-ready) event pickles live under <load_dir>/<pt_cut>_pt_cut/.
load_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/filter_processed/"
basename = os.path.join(load_dir, f"{pt_cut}_pt_cut")
train_path = os.path.join(basename, f"{train_number}_events_train.pkl")
test_path = os.path.join(basename, f"{test_number}_events_test.pkl")

In [3]:
%%time 
# Load the pre-processed train/test events. torch.load unpickles the file, so
# only point train_path / test_path at trusted data.
train_dataset = torch.load(train_path)
test_dataset = torch.load(test_path)
# One event per batch; both loaders shuffle. `[:]` slices the loaded objects
# (presumably lists of events, giving a shallow copy) — TODO confirm type.
train_loader = DataLoader(train_dataset[:], batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset[:], batch_size=1, shuffle=True)

CPU times: user 244 ms, sys: 12.2 s, total: 12.5 s
Wall time: 15.7 s


## GNN Memory Tests

In [4]:
class EdgeNetwork(nn.Module):
    """
    Per-edge scoring MLP.

    For every edge (start, end) the features of its two endpoint nodes are
    concatenated and fed through an MLP with three hidden layers of width
    ``hidden_dim`` and a final layer of width 1. The MLP is built with
    ``output_activation=None``, so the returned values are raw scores
    (callers such as ``ResAGNN`` apply ``torch.sigmoid`` themselves).

    Args:
        input_dim: feature width of a single node.
        hidden_dim: width of each hidden layer.
        hidden_activation: activation forwarded to ``make_mlp``.
        layer_norm: forwarded to ``make_mlp``.
    """
    def __init__(self, input_dim, hidden_dim=8, hidden_activation='Tanh',
                 layer_norm=True):
        super(EdgeNetwork, self).__init__()
        layer_sizes = [hidden_dim] * 3 + [1]
        self.network = make_mlp(2 * input_dim,
                                layer_sizes,
                                hidden_activation=hidden_activation,
                                output_activation=None,
                                layer_norm=layer_norm)

    def forward(self, inputs):
        """inputs: (node_features, edge_index). Returns one score per edge (1-D)."""
        node_feats, edge_index = inputs[0], inputs[1]
        src, dst = edge_index
        pair_feats = torch.cat([node_feats[src], node_feats[dst]], dim=1)
        scores = self.network(pair_feats)
        return scores.squeeze(-1)

class NodeNetwork(nn.Module):
    """
    Node-update MLP.

    Each node receives two edge-weighted sums of neighbour features:
    ``incoming`` collects start-node features over edges ending at the node,
    ``outgoing`` collects end-node features over edges starting at the node.
    Both are concatenated with the node's own features and passed through a
    4-layer MLP of width ``output_dim``.

    Args:
        input_dim: current node feature width.
        output_dim: width of every MLP layer (and of the returned features).
        hidden_activation: activation forwarded to ``make_mlp``.
        layer_norm: forwarded to ``make_mlp``.
    """
    def __init__(self, input_dim, output_dim, hidden_activation='Tanh',
                 layer_norm=True):
        super(NodeNetwork, self).__init__()
        self.network = make_mlp(3 * input_dim, [output_dim] * 4,
                                hidden_activation=hidden_activation,
                                output_activation=None,
                                layer_norm=layer_norm)

    def forward(self, inputs):
        """inputs: (node_features, edge_weights, edge_index); returns new node features."""
        node_feats, edge_weights, edge_index = inputs[0], inputs[1], inputs[2]
        src, dst = edge_index
        n_nodes = node_feats.shape[0]
        weights = edge_weights[:, None]
        incoming = scatter_add(weights * node_feats[src], dst, dim=0, dim_size=n_nodes)
        outgoing = scatter_add(weights * node_feats[dst], src, dim=0, dim_size=n_nodes)
        return self.network(torch.cat([incoming, outgoing, node_feats], dim=1))

class ResAGNN(nn.Module):
    """
    Residual attention-style GNN for edge (segment) classification.

    An input MLP embeds the node features and the raw inputs are re-appended
    to the hidden state. Each of ``n_graph_iters`` rounds then:
      * scores edges with ``EdgeNetwork`` and squashes them with a sigmoid,
      * updates nodes with ``NodeNetwork`` using those edge weights,
      * re-appends the raw inputs and adds the result onto the previous state.
    A final ``EdgeNetwork`` call returns one raw score (logit) per edge.
    """
    def __init__(self, in_channels=3, hidden_dim=8, n_graph_iters=3,
                 hidden_activation=torch.nn.Tanh, layer_norm=True):
        super(ResAGNN, self).__init__()
        self.n_graph_iters = n_graph_iters
        # Hidden state width: embedded features plus the re-appended raw inputs.
        node_width = in_channels + hidden_dim
        self.input_network = make_mlp(in_channels, [hidden_dim],
                                      output_activation=hidden_activation,
                                      layer_norm=layer_norm)
        self.edge_network = EdgeNetwork(node_width, node_width,
                                        hidden_activation, layer_norm=layer_norm)
        self.node_network = NodeNetwork(node_width, hidden_dim,
                                        hidden_activation, layer_norm=layer_norm)

    def forward(self, x, edge_index):
        """Return one raw edge score per column of ``edge_index``."""
        raw = x
        state = torch.cat([self.input_network(raw), raw], dim=-1)

        for _ in range(self.n_graph_iters):
            weights = torch.sigmoid(self.edge_network((state, edge_index)))
            update = self.node_network((state, weights, edge_index))
            state = state + torch.cat([update, raw], dim=-1)

        return self.edge_network((state, edge_index))


class CheckResAGNN(nn.Module):
    """
    Gradient-checkpointed variant of ``ResAGNN``.

    Same layers and data flow as ``ResAGNN``: an input MLP whose output is
    concatenated with the raw inputs, ``n_graph_iters`` residual rounds of
    (sigmoid-weighted ``EdgeNetwork`` -> ``NodeNetwork``), and a final
    ``EdgeNetwork`` call returning one raw score (logit) per edge. Each round
    runs under ``torch.utils.checkpoint.checkpoint`` so its intermediate
    activations are recomputed during backward instead of being stored.
    """
    def __init__(self, in_channels=3, hidden_dim=8, n_graph_iters=3,
                 hidden_activation=torch.nn.Tanh, layer_norm=True):
        super(CheckResAGNN, self).__init__()
        self.n_graph_iters = n_graph_iters
        # Setup the input network
        self.input_network = make_mlp(in_channels, [hidden_dim],
                                      output_activation=hidden_activation,
                                      layer_norm=layer_norm)
        # Setup the edge network
        self.edge_network = EdgeNetwork(in_channels + hidden_dim, in_channels + hidden_dim,
                                        hidden_activation, layer_norm=layer_norm)
        # Setup the node layers
        self.node_network = NodeNetwork(in_channels + hidden_dim, hidden_dim,
                                        hidden_activation, layer_norm=layer_norm)
        # Extra requires_grad input handed to checkpoint() so the checkpointed
        # segment always has at least one input that requires grad. It is a plain
        # attribute (not a buffer): it is not moved by .to() and not in state_dict.
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)

    def custom_forward(self, inputs):
        """
        One message-passing round (edge scoring followed by node update).

        Args:
            inputs: sequence whose first two items are node features ``x`` and
                ``edge_index``; any further items are ignored.

        Returns:
            New node features from ``node_network`` (before the residual add).
        """
        x, edge_index = inputs[0], inputs[1]
        e = torch.sigmoid(self.edge_network((x, edge_index)))
        return self.node_network((x, e, edge_index))

    def _checkpointed_round(self, x, edge_index, *unused):
        # Adapter for checkpoint(): the reentrant checkpoint only tracks tensors
        # passed as top-level positional arguments. Packing them into a single
        # tuple hid them from autograd, so the round received no gradients.
        return self.custom_forward((x, edge_index))

    def forward(self, x, edge_index):
        """Return one raw edge score per column of ``edge_index``."""
        input_x = x
        # Shortcut connect the inputs onto the hidden representation
        x = torch.cat([self.input_network(x), input_x], dim=-1)

        for _ in range(self.n_graph_iters):
            x_initial = x
            x = checkpoint(self._checkpointed_round, x, edge_index, self.dummy_tensor)
            # Re-append the raw inputs, then residual-add onto the previous state.
            x = x_initial + torch.cat([x, input_x], dim=-1)

        return self.edge_network((x, edge_index))
    
class MPNN_Network(nn.Module):
    """
    Residual message-passing network for edge classification.

    Pipeline:
      1. ``node_encoder`` maps raw node features to ``hidden_node_dim``
         (no message passing).
      2. ``n_graph_iters`` rounds of: ``edge_network`` on concatenated
         endpoint features -> sum of edge messages into each edge's
         ``edge_index[1]`` node -> ``node_network`` on [node, messages]
         -> residual add.
      3. ``edge_classifier`` scores every column of ``edge_index`` and the
         result is returned as one raw score (logit) per edge.

    Messages are only aggregated at ``edge_index[1]``, so information flows
    both ways only if ``edge_index`` contains both edge orientations.
    """

    def __init__(self, input_dim, hidden_node_dim, hidden_edge_dim, in_layers, node_layers, edge_layers,
                 n_graph_iters=1, layer_norm=True):
        super(MPNN_Network, self).__init__()
        self.n_graph_iters = n_graph_iters

        # Encoder; layer_norm is not forwarded here, so make_mlp's default applies.
        self.node_encoder = make_mlp(input_dim, [hidden_node_dim] * in_layers)
        # Edge messages from the two endpoint embeddings.
        self.edge_network = make_mlp(2 * hidden_node_dim,
                                     [hidden_edge_dim] * edge_layers,
                                     layer_norm=layer_norm)
        # Node update from [current embedding, aggregated messages].
        self.node_network = make_mlp(hidden_node_dim + hidden_edge_dim,
                                     [hidden_node_dim] * node_layers,
                                     layer_norm=layer_norm)
        # Final per-edge score.
        self.edge_classifier = make_mlp(2 * hidden_node_dim,
                                        [hidden_edge_dim, 1],
                                        output_activation=None)

    def _edge_features(self, h, edge_index):
        """Concatenate start- and end-node embeddings for every edge."""
        return torch.cat([h[edge_index[0]], h[edge_index[1]]], dim=1)

    def forward(self, x, edge_index):
        """Return one raw edge score per column of ``edge_index``."""
        h = self.node_encoder(x)

        for _ in range(self.n_graph_iters):
            messages = self.edge_network(self._edge_features(h, edge_index))
            aggregated = scatter_add(messages, edge_index[1], dim=0, dim_size=h.shape[0])
            h = self.node_network(torch.cat([h, aggregated], dim=1)) + h

        return self.edge_classifier(self._edge_features(h, edge_index)).squeeze(-1)
    
class CheckMPNN_Network(nn.Module):
    """
    Gradient-checkpointed variant of ``MPNN_Network``.

    Identical layers and data flow to ``MPNN_Network``: encode nodes, run
    ``n_graph_iters`` residual rounds of edge messages summed into each edge's
    ``edge_index[1]`` node, then score every column of ``edge_index``. The
    edge network, node network and edge classifier are each wrapped in
    ``torch.utils.checkpoint.checkpoint`` so their activations are recomputed
    during backward rather than stored.
    """

    def __init__(self, input_dim, hidden_node_dim, hidden_edge_dim, in_layers, node_layers, edge_layers,
                 n_graph_iters=1, layer_norm=True):
        super(CheckMPNN_Network, self).__init__()
        self.n_graph_iters = n_graph_iters

        # Encoder; layer_norm is not forwarded here, so make_mlp's default applies.
        self.node_encoder = make_mlp(input_dim, [hidden_node_dim] * in_layers)
        # Edge messages from the two endpoint embeddings.
        self.edge_network = make_mlp(2 * hidden_node_dim,
                                     [hidden_edge_dim] * edge_layers,
                                     layer_norm=layer_norm)
        # Node update from [current embedding, aggregated messages].
        self.node_network = make_mlp(hidden_node_dim + hidden_edge_dim,
                                     [hidden_node_dim] * node_layers,
                                     layer_norm=layer_norm)
        # Final per-edge score.
        self.edge_classifier = make_mlp(2 * hidden_node_dim,
                                        [hidden_edge_dim, 1],
                                        output_activation=None)

    def forward(self, x, edge_index):
        """Return one raw edge score per column of ``edge_index``."""
        h = self.node_encoder(x)
        src, dst = edge_index[0], edge_index[1]

        for _ in range(self.n_graph_iters):
            pair = torch.cat([h[src], h[dst]], dim=1)
            messages = checkpoint(self.edge_network, pair)
            aggregated = scatter_add(messages, dst, dim=0, dim_size=h.shape[0])
            h = checkpoint(self.node_network, torch.cat([h, aggregated], dim=1)) + h

        pair = torch.cat([h[src], h[dst]], dim=1)
        return checkpoint(self.edge_classifier, pair).squeeze(-1)

In [5]:
class Filter(torch.nn.Module):
    """
    MLP edge filter.

    Scores each edge ``(e[0], e[1])`` from the concatenated features of its two
    endpoints and, optionally, their embeddings. Returns raw scores of shape
    ``(n_edges, 1)``.

    Args:
        in_channels: width of the node features ``x``.
        emb_channels: width of the node embeddings ``emb``. ``input_layer`` is
            sized for ``2*in_channels + 2*emb_channels`` inputs, so calling
            ``forward`` without ``emb`` requires ``emb_channels == 0``.
        hidden: width of the hidden layers.
        nb_layer: number of Linear layers before the output layer
            (``input_layer`` plus ``nb_layer - 1`` hidden layers).
    """
    def __init__(self, in_channels, emb_channels, hidden, nb_layer):
        super(Filter, self).__init__()
        self.emb_channels = emb_channels
        self.input_layer = Linear(in_channels*2 + emb_channels*2, hidden)
        layers = [Linear(hidden, hidden) for _ in range(nb_layer-1)]
        self.layers = nn.ModuleList(layers)
        self.output_layer = nn.Linear(hidden, 1)
        # Registered (so part of state_dict) but currently unused in forward.
        self.norm = nn.LayerNorm(hidden)
        self.act = nn.Tanh()

    def forward(self, x, e, emb=None):
        """
        Args:
            x: node features, indexed by the entries of ``e``.
            e: edge index; ``e[0]`` / ``e[1]`` select start / end nodes.
            emb: optional node embeddings, indexed like ``x``.

        Raises:
            ValueError: if ``emb`` is omitted but the filter was built with
                ``emb_channels > 0`` (the input layer would not match).
        """
        if emb is not None:
            edge_inputs = torch.cat([x[e[0]], emb[e[0]], x[e[1]], emb[e[1]]], dim=-1)
        else:
            # getattr: instances unpickled from before emb_channels was stored.
            if getattr(self, "emb_channels", 0):
                raise ValueError(
                    "Filter was built with emb_channels={} but forward() got emb=None; "
                    "pass emb or build the filter with emb_channels=0".format(self.emb_channels))
            edge_inputs = torch.cat([x[e[0]], x[e[1]]], dim=-1)

        # NOTE(review): no activation is applied directly after input_layer, so it
        # composes linearly with the first hidden layer — confirm this is intended.
        h = self.input_layer(edge_inputs)
        for layer in self.layers:
            h = self.act(layer(h))
        # Option of LayerNorm: h = self.norm(h)
        return self.output_layer(h)

In [5]:
def train_gnn(model, train_loader, optimizer, m_configs):
    """
    Run one training epoch of an edge-classifier GNN.

    For each batch (one event) the edges are optionally subsampled, the model
    scores the selected edges, and a weighted BCE-with-logits loss is minimised
    against a same-particle target: an edge is true iff both endpoints share a
    ``pid``.

    Args:
        model: callable ``model(x, edge_index)`` returning one logit per edge.
        train_loader: iterable of batches exposing ``x``, ``e_radius``, ``y``
            and ``pid``.
        optimizer: torch optimizer over ``model``'s parameters.
        m_configs: dict with ``'ratio'`` (0 = use every edge; otherwise keep all
            ``y``-true edges plus ``ratio`` times as many ``y``-false edges drawn
            with replacement) and optional ``'pos_weight'`` (default 3).

    Returns:
        Sum of the per-batch losses (float).
    """
    pos_weight = torch.tensor(float(m_configs.get('pos_weight', 3)))
    total_loss = 0

    for batch in train_loader:
        optimizer.zero_grad()

        if m_configs['ratio'] != 0:
            y_true = batch.y.bool()
            true_indices = np.where(y_true)[0]
            n_fake = int(batch.y.sum() * m_configs['ratio'])
            fake_indices = np.random.choice(np.where(~y_true)[0], n_fake, replace=True)
            edges = batch.e_radius[:, np.concatenate([true_indices, fake_indices])]
        else:
            edges = batch.e_radius

        # reshape(-1) instead of squeeze(): keeps a 1-D output for single-edge events.
        output = model(batch.x.to(device), edges.to(device)).reshape(-1)

        # Targets must come from the same edges the model scored; using the full
        # e_radius breaks (shape mismatch) whenever edges were subsampled.
        y_pid = (batch.pid[edges[0]] == batch.pid[edges[1]]).float()

        loss = F.binary_cross_entropy_with_logits(output.cpu(), y_pid, pos_weight=pos_weight)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    return total_loss

def evaluate_gnn(model, test_loader, m_configs):
    """
    Evaluate an edge classifier against same-particle (``pid``) truth.

    Args:
        model: callable ``model(x, edge_index)`` returning one logit per edge.
        test_loader: iterable of batches exposing ``x``, ``e_radius`` and ``pid``.
        m_configs: dict; if ``'ratio'`` is non-zero only 10000 randomly chosen
            edges (with replacement) are scored per batch, otherwise all edges.

    Returns:
        ``(edge_pur, edge_eff, total_loss)``. Purity is TP / predicted positives
        and efficiency is TP / true edges, both 0-dim float tensors with their
        denominators clamped to at least 1. ``total_loss`` is the summed
        unweighted BCE-with-logits loss.
    """
    # Tensor accumulators so the final ratios also work for an empty loader.
    total_true_positive = torch.zeros((), dtype=torch.long)
    total_positive = torch.zeros((), dtype=torch.long)
    total_true = torch.zeros(())
    total_loss = 0

    for batch in test_loader:
        n_edges = batch.e_radius.shape[1]
        if m_configs['ratio'] != 0:
            subset_ind = np.random.randint(0, n_edges, 10000)
        else:
            subset_ind = np.arange(0, n_edges)
        edges = batch.e_radius[:, subset_ind]
        output = model(batch.x.to(device), edges.to(device)).reshape(-1).cpu()

        # Truth for exactly the edges that were scored (the full e_radius would
        # mismatch the output whenever a subset is used).
        y_pid = (batch.pid[edges[0]] == batch.pid[edges[1]]).float()

        loss = F.binary_cross_entropy_with_logits(output, y_pid)
        total_loss += loss.item()

        preds = torch.sigmoid(output) > 0.5
        total_true_positive += (y_pid.bool() & preds).sum()
        total_positive += preds.sum()
        total_true += y_pid.sum()

    edge_eff = total_true_positive.float() / total_true.clamp(min=1)
    edge_pur = total_true_positive.float() / total_positive.float().clamp(min=1)

    return edge_pur, edge_eff, total_loss


### Benchmark Tests

In [6]:
torch.manual_seed(torch_seed)

# Baseline (non-checkpointed) model. Alternative:
# m_configs = {"in_channels": 3, "hidden_dim": 32, "n_graph_iters": 6}
# model = ResAGNN(**m_configs).to(device)
m_configs = {"input_dim": 3, "hidden_node_dim": 64, "hidden_edge_dim": 128, "in_layers": 2, "node_layers": 4, "edge_layers": 4, "n_graph_iters": 8, "layer_norm": True}
model = MPNN_Network(**m_configs).to(device)

# Training options; merged after model construction so they are logged to wandb.
other_configs = {'ratio': 0, 'test_subset': 0.001}
m_configs.update(other_configs)
model_name = wandb.init(project="CheckpointExploration", config=m_configs)  # wandb Run handle
wandb.watch(model, log='all')
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-3, amsgrad=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=5)

# Reset the CUDA peak-memory counter so the next cell measures only its own step.
# reset_max_memory_allocated is deprecated, and CUDA calls fail on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.9.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [7]:
model.train()

total_loss = 0
tic = tt()
# Benchmark a single optimisation step on the first training event.
batch = next(iter(train_loader))
optimizer.zero_grad()

if m_configs['ratio'] != 0:
    num_true = batch.y.sum()
    fake_indices = np.random.choice(np.where(~batch.y.bool())[0], int(num_true*m_configs['ratio']), replace=True)
    true_indices = np.where(batch.y.bool())[0]
    combined_indices = np.concatenate([true_indices, fake_indices])

    x, e = batch.x.to(device), batch.e_radius[:, combined_indices].to(device)
    weight = torch.tensor(float(m_configs['ratio']))
    combined_y = batch.y[combined_indices]
else:
    x, e = batch.x.to(device), batch.e_radius.to(device)
    # Up-weight true edges by the false:true ratio (already a tensor; no torch.tensor copy).
    weight = (~batch.y.bool()).sum() / batch.y.sum()
    combined_y = batch.y
# Gradient w.r.t. the input as well (the checkpointed benchmark does not do this).
x.requires_grad_(True)

output = model(x, e)

loss = F.binary_cross_entropy_with_logits(output, combined_y.to(device), pos_weight=weight.to(device))
total_loss += loss.item()
loss.backward()
optimizer.step()

# CUDA work is asynchronous: synchronise so the timing covers backward + step.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    print("Memory max (GB):", torch.cuda.max_memory_allocated() / 1024**3, "- time:", tt()-tic)
else:
    print("time:", tt()-tic)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.9.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Memory max (GB): 9.888021469116211 - time: 1.5440750122070312


In [61]:
# Snapshot the baseline forward output and per-parameter gradients so they can
# be compared with the checkpointed run; parameters without a gradient are reported.
output_not_checkpointed = output.detach().clone()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
    if param.grad is None:
        print(name, "is NONE!")
        continue
    grad_not_checkpointed[name] = param.grad.detach().clone()

In [65]:
output_not_checkpointed

tensor([0.6636, 0.6636, 0.6636,  ..., 0.9206, 1.1540, 1.1405], device='cuda:0')

In [66]:
output_checkpointed

tensor([0.6636, 0.6636, 0.6636,  ..., 0.9206, 1.1540, 1.1405], device='cuda:0')

In [67]:
grad_not_checkpointed

{'node_encoder.0.weight': tensor([[-7.4650e-04, -8.5498e-04,  8.2160e-04],
         [ 1.0996e-03, -7.5045e-03, -3.9679e-04],
         [ 6.0399e-05,  2.7517e-04, -3.9528e-05],
         [-8.7073e-05,  3.9059e-05,  4.5261e-05],
         [-3.1462e-06,  2.7884e-05,  3.5855e-06],
         [-4.9015e-04,  2.0558e-03,  5.0035e-04],
         [ 3.6316e-06, -1.3667e-05,  1.3668e-05],
         [-1.5317e-03,  4.7140e-03, -1.2496e-03],
         [-2.7532e-03, -1.6618e-03, -2.5133e-03],
         [ 2.3405e-03,  1.3391e-03, -1.2645e-03],
         [ 6.8052e-05, -4.7394e-04, -6.6557e-05],
         [ 1.7535e-04, -5.0981e-04, -2.7442e-04],
         [ 3.4337e-04,  2.8137e-03,  1.4704e-04],
         [-5.5038e-03, -7.2412e-04, -8.1829e-04],
         [-8.1277e-04,  7.7152e-04,  6.8225e-04],
         [ 2.6660e-03,  9.2370e-04, -1.0801e-03],
         [-8.9833e-05,  6.8392e-03, -8.2161e-04],
         [ 2.9409e-03, -1.0832e-03, -1.2232e-03],
         [-2.9338e-03,  6.9049e-03,  1.4955e-03],
         [-7.8155e-04,  4

In [68]:
grad_checkpointed

{'node_encoder.0.weight': tensor([[-7.4650e-04, -8.5498e-04,  8.2160e-04],
         [ 1.0996e-03, -7.5044e-03, -3.9677e-04],
         [ 6.0400e-05,  2.7517e-04, -3.9527e-05],
         [-8.7072e-05,  3.9058e-05,  4.5261e-05],
         [-3.1462e-06,  2.7884e-05,  3.5855e-06],
         [-4.9014e-04,  2.0558e-03,  5.0034e-04],
         [ 3.6316e-06, -1.3667e-05,  1.3668e-05],
         [-1.5317e-03,  4.7140e-03, -1.2496e-03],
         [-2.7531e-03, -1.6618e-03, -2.5133e-03],
         [ 2.3405e-03,  1.3391e-03, -1.2645e-03],
         [ 6.8052e-05, -4.7394e-04, -6.6557e-05],
         [ 1.7535e-04, -5.0981e-04, -2.7442e-04],
         [ 3.4337e-04,  2.8137e-03,  1.4704e-04],
         [-5.5038e-03, -7.2406e-04, -8.1829e-04],
         [-8.1278e-04,  7.7152e-04,  6.8226e-04],
         [ 2.6660e-03,  9.2366e-04, -1.0801e-03],
         [-8.9837e-05,  6.8392e-03, -8.2161e-04],
         [ 2.9409e-03, -1.0832e-03, -1.2232e-03],
         [-2.9338e-03,  6.9049e-03,  1.4955e-03],
         [-7.8155e-04,  4

In [90]:
print("Memory max (GB):", torch.cuda.max_memory_allocated() / 1024**3)

Memory max (GB): 1.2986721992492676


### Checkpointed Tests

In [12]:
torch.manual_seed(torch_seed)

# Checkpointed model. Alternative:
# m_configs = {"in_channels": 3, "hidden_dim": 32, "n_graph_iters": 6}
# model = CheckResAGNN(**m_configs).to(device)
# NOTE(review): hidden sizes (256/512) differ from the baseline cell (64/128), so
# peak-memory numbers from the two benchmarks are not directly comparable.
m_configs = {"input_dim": 3, "hidden_node_dim": 256, "hidden_edge_dim": 512, "in_layers": 2, "node_layers": 4, "edge_layers": 4, "n_graph_iters": 8, "layer_norm": True}
model = CheckMPNN_Network(**m_configs).to(device)

# Training options; merged after model construction so they are logged to wandb.
other_configs = {'ratio': 0, 'test_subset': 0.001}
m_configs.update(other_configs)
model_name = wandb.init(project="CheckpointExploration", config=m_configs)  # wandb Run handle
wandb.watch(model, log='all')
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-3, amsgrad=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=5)

# Reset the CUDA peak-memory counter so the next cell measures only its own step.
# reset_max_memory_allocated is deprecated, and CUDA calls fail on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.9.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [11]:
model.train()

total_loss = 0
tic = tt()
# Benchmark a single optimisation step on the first training event.
batch = next(iter(train_loader))
optimizer.zero_grad()

if m_configs['ratio'] != 0:
    num_true = batch.y.sum()
    fake_indices = np.random.choice(np.where(~batch.y.bool())[0], int(num_true*m_configs['ratio']), replace=True)
    true_indices = np.where(batch.y.bool())[0]
    combined_indices = np.concatenate([true_indices, fake_indices])

    x, e = batch.x.to(device), batch.e_radius[:, combined_indices].to(device)
    weight = torch.tensor(float(m_configs['ratio']))
    combined_y = batch.y[combined_indices]
else:
    x, e = batch.x.to(device), batch.e_radius.to(device)
    # Up-weight true edges by the false:true ratio (already a tensor; no torch.tensor copy).
    weight = (~batch.y.bool()).sum() / batch.y.sum()
    combined_y = batch.y

output = model(x, e)

loss = F.binary_cross_entropy_with_logits(output, combined_y.to(device), pos_weight=weight.to(device))
total_loss += loss.item()
loss.backward()
optimizer.step()

# CUDA work is asynchronous: synchronise so the timing covers backward + step.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    print("Memory max (GB):", torch.cuda.max_memory_allocated() / 1024**3, "- time:", tt()-tic)
else:
    print("time:", tt()-tic)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.9.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Memory max (GB): 6.213899612426758 - time: 1.9646477699279785


In [64]:
# Snapshot the checkpointed forward output and per-parameter gradients for
# comparison with the baseline run; parameters without a gradient are reported.
output_checkpointed = output.detach().clone()
grad_checkpointed = {}
for name, param in model.named_parameters():
    if param.grad is None:
        print(name, "is NONE!")
        continue
    grad_checkpointed[name] = param.grad.detach().clone()

In [48]:
# Print each parameter's gradient (None for parameters that received none).
for k, v in model.named_parameters():
    print(v.grad)

tensor([[ 3.4827e-04, -1.6849e-04,  1.6293e-04],
        [-6.9217e-04,  1.0137e-06, -1.5192e-05],
        [-1.1422e-03, -4.9120e-04,  5.0173e-04],
        [-2.0882e-03,  4.6678e-04,  8.0273e-04],
        [-5.8277e-04,  1.0323e-04, -4.9846e-04],
        [ 7.3378e-04, -1.2949e-03, -4.7639e-04],
        [-2.6608e-04,  4.2600e-04,  8.8583e-04],
        [ 4.6850e-03, -4.5987e-04, -3.0314e-03],
        [-3.4472e-04,  7.9671e-05,  1.2699e-04],
        [-3.2708e-03,  9.7536e-04,  1.4880e-03],
        [-2.8194e-04, -1.3201e-04,  3.0481e-04],
        [ 1.3466e-03, -1.3849e-03, -1.1392e-03],
        [-1.3290e-03, -4.9063e-04,  1.2931e-03],
        [ 9.5799e-04,  2.1863e-04,  1.8548e-04],
        [ 3.1812e-03, -1.8188e-04, -1.8478e-03],
        [-4.5667e-04, -6.9510e-04,  7.0827e-04],
        [-1.4299e-03,  2.7700e-04,  4.7499e-04],
        [-1.9144e-04, -1.1173e-04,  4.4829e-05],
        [ 3.4400e-04,  1.4163e-04, -6.3307e-05],
        [ 3.8257e-04,  2.5238e-05, -5.2217e-04],
        [-1.1863e-03

### Regular Training

In [13]:
# Full training run: per epoch, train on train_loader, validate on test_loader,
# log to wandb, step the LR scheduler on the validation loss and save a checkpoint.
for epoch in range(100):
    tic = tt() 
    model.train()
    train_loss = train_gnn(model, train_loader, optimizer, m_configs)
    print('Training loss: {:.4f} in time {}'.format(train_loss, tt() - tic))

    model.eval()
    with torch.no_grad():
        edge_pur, edge_eff, val_loss = evaluate_gnn(model, test_loader, m_configs)
    wandb.log({"val_loss": val_loss, "train_loss": train_loss, "edge_pur": edge_pur, "edge_eff": edge_eff, "lr": optimizer.param_groups[0]['lr']})
    # ReduceLROnPlateau: lowers the LR when val_loss stops improving.
    scheduler.step(val_loss)

    # NOTE(review): `_name` is a private attribute of the wandb Run, and the
    # 'ResAGNN/' directory does not match the MPNN models built above — confirm.
    save_model(epoch, model, optimizer, scheduler, val_loss, m_configs, 'ResAGNN/'+model_name._name+'.tar')

    # The elapsed time printed here covers both training and validation.
    print('Epoch: {}, Eff: {:.4f}, Pur: {:.4f}, Loss: {:.4f}, LR: {} in time {}'.format(epoch, edge_eff, edge_pur, val_loss, optimizer.param_groups[0]['lr'], tt()-tic))

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.9.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Training loss: 999.6741 in time 1369.7977874279022
Epoch: 0, Eff: 0.8926, Pur: 0.5942, Loss: 56.1014, LR: 0.001 in time 1403.7224946022034
Training loss: 965.6096 in time 1353.3591628074646
Epoch: 1, Eff: 0.8808, Pur: 0.5465, Loss: 64.1256, LR: 0.001 in time 1392.6192655563354


KeyboardInterrupt: 

In [85]:
# Compare filter performance against two truth definitions:
#   * `y`     — the stored edge labels (used for the loss and efficiency),
#   * `y_pid` — edge connects two hits with the same particle id.
model.eval()
with torch.no_grad():
    edge_total_positive, edge_total_true, edge_total_true_positive, edge_total_true_positive_pid = [0]*4
    total_loss = 0

    for i, batch in enumerate(test_loader):
        if m_configs['ratio'] != 0:
            # Random 10000-edge sample (with replacement) to keep evaluation cheap.
            subset_ind = np.random.randint(0, batch.e_radius.shape[1], 10000)
        else:
            subset_ind = np.arange(0, batch.e_radius.shape[1])
        x, e = batch.x.to(device), batch.e_radius[:, subset_ind].to(device)
        output = model(x, e).squeeze()

        loss = F.binary_cross_entropy_with_logits(output.cpu(), batch.y[subset_ind])
        total_loss += loss.item()

        # Computed over all edges, then indexed with subset_ind below.
        y_pid = batch.pid[batch.e_radius[0]] == batch.pid[batch.e_radius[1]]
        # torch.sigmoid replaces the deprecated F.sigmoid.
        preds = torch.sigmoid(output.cpu()) > 0.5

        # Edge filter performance
        edge_true = batch.y[subset_ind].sum()
        edge_true_positive = (batch.y[subset_ind].bool() & preds).sum()
        edge_true_positive_pid = (y_pid[subset_ind] & preds).sum()
        edge_positive = preds.sum()

        edge_total_true_positive += edge_true_positive
        edge_total_true_positive_pid += edge_true_positive_pid
        edge_total_positive += edge_positive
        edge_total_true += edge_true

        print("True positive:", edge_true_positive, "True positive (PID):", edge_true_positive_pid, "True:", edge_true, "Positive:", edge_positive)

    edge_eff = edge_total_true_positive.float() / max(edge_total_true, 1)
    edge_pur = edge_total_true_positive.float() / max(edge_total_positive.float(), 1)
    edge_pur_pid = edge_total_true_positive_pid.float() / max(edge_total_positive.float(), 1)

True positive: tensor(33530) True positive (PID): tensor(33538) True: tensor(35476.) Positive: tensor(45499)
True positive: tensor(31225) True positive (PID): tensor(31232) True: tensor(32718.) Positive: tensor(41038)
True positive: tensor(33512) True positive (PID): tensor(33522) True: tensor(35390.) Positive: tensor(45257)
True positive: tensor(35423) True positive (PID): tensor(35425) True: tensor(37574.) Positive: tensor(48595)
True positive: tensor(30360) True positive (PID): tensor(30366) True: tensor(31776.) Positive: tensor(39814)
True positive: tensor(34982) True positive (PID): tensor(34987) True: tensor(36988.) Positive: tensor(47548)
True positive: tensor(36758) True positive (PID): tensor(36759) True: tensor(38900.) Positive: tensor(50623)
True positive: tensor(42887) True positive (PID): tensor(42889) True: tensor(46440.) Positive: tensor(61792)
True positive: tensor(37886) True positive (PID): tensor(37888) True: tensor(40665.) Positive: tensor(53168)
True positive: tens

True positive: tensor(32872) True positive (PID): tensor(32873) True: tensor(34430.) Positive: tensor(43475)
True positive: tensor(47907) True positive (PID): tensor(47912) True: tensor(52722.) Positive: tensor(72195)
True positive: tensor(35971) True positive (PID): tensor(35981) True: tensor(38328.) Positive: tensor(49772)
True positive: tensor(38802) True positive (PID): tensor(38802) True: tensor(41656.) Positive: tensor(55295)
True positive: tensor(36861) True positive (PID): tensor(36865) True: tensor(39356.) Positive: tensor(51002)
True positive: tensor(32459) True positive (PID): tensor(32461) True: tensor(34024.) Positive: tensor(42951)
True positive: tensor(34171) True positive (PID): tensor(34171) True: tensor(36020.) Positive: tensor(45843)
True positive: tensor(39169) True positive (PID): tensor(39172) True: tensor(41806.) Positive: tensor(54832)
True positive: tensor(38713) True positive (PID): tensor(38716) True: tensor(41206.) Positive: tensor(54421)
True positive: tens

In [84]:
print(edge_eff, edge_pur, edge_pur_pid)

tensor(0.9393) tensor(0.7222) tensor(0.7223)


In [65]:
y_pid = batch.pid[batch.e_radius[0]] == batch.pid[batch.e_radius[1]]

In [68]:
y_truth = batch.y.bool()

In [78]:
y_pid.sum().item()/len(y_truth)

0.3880310215372456

In [74]:
y_truth.sum().item()/len(y_truth)

0.2365540407034183

In [75]:
batch

Batch(batch=[24141], e_radius=[2, 202440], embedding=[24141, 64], event_file=[1], hid=[24141], layerless_true_edges=[2, 24513], layers=[24141], layerwise_true_edges=[2, 19272], pid=[24141], x=[24141, 3], y=[202440])

In [77]:
y_truth.sum().item() / (2*batch.layerless_true_edges.shape[1])

0.9767878268673765

In [13]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    6883 KB |  723332 KB |    3941 GB |    3941 GB |
|       from large pool |    4799 KB |  707787 KB |    3842 GB |    3842 GB |
|       from small pool |    2084 KB |   19339 KB |      99 GB |      99 GB |
|---------------------------------------------------------------------------|
| Active memory         |    6883 KB |  723332 KB |    3941 GB |    3941 GB |
|       from large pool |    4799 KB |  707787 KB |    3842 GB |    3842 GB |
|       from small pool |    2084 KB |   19339 KB |      99 GB |      99 GB |
|---------------------------------------------------------------