In [14]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to local modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
### Generic imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm
import scipy
import uproot
from tqdm import tqdm
import functools
import pickle
from glob import glob

### ML-related
import tensorflow as tf
import atlas_mpl_style as ampl
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import sonnet as snt

### GNN-related
from graph_nets import blocks
from graph_nets import graphs
from graph_nets import modules
from graph_nets import utils_np
from graph_nets import utils_tf
import networkx as nx

### Custom imports 
from models import *

In [16]:
### GPU Setup
# Order CUDA devices by PCI bus ID and expose only one of them to this process.
# NOTE(review): these variables only take effect if they are set before
# TensorFlow initialises its GPU devices — keep this cell ahead of any TF work.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3" # pick a number between 0 & 3
gpus = tf.config.list_physical_devices('GPU')
# Allocate GPU memory on demand instead of reserving it all up front.
# Looping (rather than indexing gpus[0]) keeps the cell from raising
# IndexError on a machine where no GPU is visible.
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if not gpus:
    print("No GPU visible to TensorFlow; running on CPU.")

In [17]:
### Other setup
# Limit how many columns / rows pandas renders when displaying a DataFrame.
for _option, _limit in (('display.max_columns', 200), ('display.max_rows', 20)):
    pd.set_option(_option, _limit)

# Matplotlib defaults applied to every figure drawn below.
params = {'legend.fontsize': 13, 'axes.labelsize': 18}
plt.rcParams.update(params)

# One seed shared by NumPy and TensorFlow (also reused for the train/test split).
SEED = 15
tf.random.set_seed(SEED)
np.random.seed(SEED)

### Load graphs into a single list

In [18]:
# glob returns paths in arbitrary, filesystem-dependent order; sorting makes the
# "first 10 files" selection below identical from run to run.
files = sorted(glob("../graphs/*_pion/*.pkl"))

# NOTE(review): unpickling executes arbitrary code — only load files you trust.
# Each file is opened in a `with` block so its handle is closed right after
# loading instead of being left for the garbage collector.
graph_list = []
for file in tqdm(files[:10]):
    with open(file, 'rb') as handle:
        graph_list.append(pickle.load(handle))

100%|██████████| 10/10 [00:00<00:00, 60.15it/s]


In [19]:
# Hold out 20% of the loaded items for testing; the split is seeded with SEED.
train_graphs, test_graphs = train_test_split(
    graph_list,
    test_size=0.2,
    random_state=SEED,
)
for _label, _subset in (
    ("Number of training graph pairs: ", train_graphs),
    ("Number of testing graph pairs: ", test_graphs),
):
    print(_label, len(_subset))

Number of training graph pairs:  8
Number of testing graph pairs:  2


In [70]:
# Spot check: show the `globals` attribute stored on test_graphs[0][1][0]
# (presumably a graph_nets GraphsTuple — TODO confirm).
test_graphs[0][1][0].globals

<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[22.377983 ,  2.4487834, -2.0464356]], dtype=float32)>

In [54]:
# Spot check: show the `n_edge` attribute of test_graphs[0][3][0] as a NumPy array.
np.array(test_graphs[0][3][0].n_edge)

array([500], dtype=int32)

# Define GNN model

In [21]:
# Build the network with its default constructor arguments.
# GlobalClassifier is not defined in this notebook; presumably it is provided
# by `from models import *` — verify.
model = GlobalClassifier()

In [None]:
# Number of graphs concatenated into one batch before a model call.
batch_size = 200
# NOTE(review): `df` is not defined in any cell visible in this notebook, so this
# line raises NameError on a fresh kernel. Confirm what get_signature expects
# (presumably sample graph data) and pass that instead.
input_signature = get_signature(df, batch_size)

# Model parameters.
# Number of processing (message-passing) steps.
# These are passed as the second argument of every model call (train / test).
num_processing_steps_tr = 2
num_processing_steps_te = 2

# Sonnet Adam optimizer used by update_step.
learning_rate = 1e-3
optimizer = snt.optimizers.Adam(learning_rate)

# model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
# Bookkeeping containers. Only losses_tr and losses_test are filled by the
# visible training loop; the other names are not used in the cells shown here.
last_iteration = 0
generalization_iteration = 0

logged_iterations = []

losses_tr = []
corrects_tr = []
solveds_tr = []

losses_test = []
corrects_test = []
solveds_test = []

@functools.partial(tf.function, input_signature=input_signature)
def update_step(inputs_tr, targets_tr):
    """Run one optimisation step on a single training batch.

    Args:
        inputs_tr: batched input graphs fed to the model.
        targets_tr: batched target graphs the outputs are compared against.

    Returns:
        A tuple ``(outputs, loss)``: the model outputs for the batch and the
        scalar loss, i.e. the summed output of ``loss_function_global``
        divided by ``num_processing_steps_tr``.
    """
    # Python-side print: it fires when tf.function traces this function,
    # not on every call of the compiled graph.
    print("Tracing update_step")
    step_count = tf.constant(num_processing_steps_tr, dtype=tf.float32)
    with tf.GradientTape() as grad_tape:
        predictions = model(inputs_tr, num_processing_steps_tr)
        step_losses = loss_function_global(targets_tr, predictions)
        batch_loss = tf.math.reduce_sum(step_losses) / step_count

    # Read the variables only after the forward pass above has run.
    variables = model.trainable_variables
    grads = grad_tape.gradient(batch_loss, variables)
    optimizer.apply(grads, variables)
    return predictions, batch_loss

In [None]:
def _iter_batches(graph_files, batch_size):
    """Yield ``(inputs, targets)`` batches of ``batch_size`` concatenated graphs.

    Args:
        graph_files: iterable whose elements are themselves sequences of
            ``(input_graph, target_graph)`` pairs.
            NOTE(review): this layout is inferred from the inspection cells,
            which index ``test_graphs[0][k][0]`` for k up to 3 — so one element
            cannot be a single pair. Verify against the graph-building script.
        batch_size: number of graphs concatenated into one batch.

    Pairs whose input is None are skipped. A trailing partial batch is dropped,
    so every yielded batch holds exactly ``batch_size`` graphs.
    """
    input_list, target_list = [], []
    for file_pairs in graph_files:
        for input_graph, target_graph in file_pairs:
            if input_graph is None:
                continue
            input_list.append(input_graph)
            target_list.append(target_graph)
            if len(input_list) >= batch_size:
                yield (utils_tf.concat(input_list, axis=0),
                       utils_tf.concat(target_list, axis=0))
                input_list, target_list = [], []


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when no batch was processed."""
    return total / count if count else float("nan")


for epoch in range(10):
    # training dataset
    total_loss = 0.
    num_batches = 0
    for input_tr, target_tr in _iter_batches(train_graphs, batch_size):
        total_loss += update_step(input_tr, target_tr)[1].numpy()
        num_batches += 1
    # NaN (instead of ZeroDivisionError) when fewer than batch_size graphs exist.
    loss_tr = _mean_or_nan(total_loss, num_batches)
    losses_tr.append(loss_tr)

    # testing dataset: forward pass only, no gradient update
    total_loss_test = 0.
    num_batches_test = 0
    for input_test, target_test in _iter_batches(test_graphs, batch_size):
        output_test = model(input_test, num_processing_steps_te)
        loss_ops_test = loss_function_global(target_test, output_test)
        # Normalise by the *test* step count, matching the model call above.
        batch_loss_test = tf.math.reduce_sum(loss_ops_test) / tf.constant(num_processing_steps_te, dtype=tf.float32)
        # Accumulate a plain number so the log line shows a value, not a tensor repr.
        total_loss_test += batch_loss_test.numpy()
        num_batches_test += 1
    loss_test = _mean_or_nan(total_loss_test, num_batches_test)
    losses_test.append(loss_test)

    print_string = f"Epoch: {epoch}\tloss_tr: {loss_tr}\tloss_test: {loss_test}"
    with open("log2.txt", "a") as log:
        log.write(print_string + "\n")
    print(print_string)
# TODO add a checkpoint > save to disk

In [None]:
# Per-epoch loss curves for the training and test sets.
# An explicit figure/axes pair is used instead of the implicit pyplot state,
# and the axes are labelled so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(losses_tr, label="Train")
ax.plot(losses_test, label="Test")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title("Losses per epoch");
# fig.savefig("loss_graph.png")