# Graph Networts

In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
#import libraries and some constants

import os
import time
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm
import pandas as pd
import tensorflow as tf

pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

import uproot3 as ur
# import atlas_mpl_style as ampl
# ampl.use_atlas_style()

path_prefix = '/global/home/users/mfong/git/LCStudies/'
plotpath = path_prefix + 'classifier/Plots/'
modelpath = path_prefix + 'classifier/Models/'
# %config InlineBackend.figure_format = 'svg'

# metadata
layers = ["EMB1", "EMB2", "EMB3", "TileBar0", "TileBar1", "TileBar2"]
cell_size_phi = [0.098, 0.0245, 0.0245, 0.1, 0.1, 0.1]
cell_size_eta = [0.0031, 0.025, 0.05, 0.1, 0.1, 0.2]
len_phi = [4, 16, 16, 4, 4, 4]
len_eta = [128, 16, 8, 4, 4, 2]

SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Xiangyang run from here to get dummy data

In [14]:
# read dummy data
dummy_df = pd.read_csv("dummy_df.csv")
dummy_df

Unnamed: 0,EMB1_0,EMB1_1,EMB1_2,EMB1_3,EMB1_4,EMB1_5,EMB1_6,EMB1_7,EMB1_8,EMB1_9,EMB1_10,EMB1_11,EMB1_12,EMB1_13,EMB1_14,EMB1_15,EMB1_16,EMB1_17,EMB1_18,EMB1_19,EMB1_20,EMB1_21,EMB1_22,EMB1_23,EMB1_24,EMB1_25,EMB1_26,EMB1_27,EMB1_28,EMB1_29,EMB1_30,EMB1_31,EMB1_32,EMB1_33,EMB1_34,EMB1_35,EMB1_36,EMB1_37,EMB1_38,EMB1_39,EMB1_40,EMB1_41,EMB1_42,EMB1_43,EMB1_44,EMB1_45,EMB1_46,EMB1_47,EMB1_48,EMB1_49,...,EMB3_119,EMB3_120,EMB3_121,EMB3_122,EMB3_123,EMB3_124,EMB3_125,EMB3_126,EMB3_127,TileBar0_0,TileBar0_1,TileBar0_2,TileBar0_3,TileBar0_4,TileBar0_5,TileBar0_6,TileBar0_7,TileBar0_8,TileBar0_9,TileBar0_10,TileBar0_11,TileBar0_12,TileBar0_13,TileBar0_14,TileBar0_15,TileBar1_0,TileBar1_1,TileBar1_2,TileBar1_3,TileBar1_4,TileBar1_5,TileBar1_6,TileBar1_7,TileBar1_8,TileBar1_9,TileBar1_10,TileBar1_11,TileBar1_12,TileBar1_13,TileBar1_14,TileBar1_15,TileBar2_0,TileBar2_1,TileBar2_2,TileBar2_3,TileBar2_4,TileBar2_5,TileBar2_6,TileBar2_7,is_p0
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.973252,0.026748,0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.026335,0.000000,0.000000,0.072234,0.764238,0.016248,0.000000,0.002098,0.067689,0.026477,0.001169,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.002202,0.000000,0.000000,0.003275,0.015712,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.024622,0.00000,0.178025,0.219936,0.000000,0.000000,0.118353,0.431438,0.000000,0.000000,0.009038,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,1
1996,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000032,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00191,0.000000,0.000000,0.000000,0.006119,0.036603,0.000000,0.007777,0.002578,0.008802,0.019965,0.000219,0.003530,0.000235,0.000594,0.000055,0.006055,0.015197,0.046428,0.004145,0.00122,0.037687,0.694705,0.016177,0.006726,0.015592,0.019869,0.001952,0.000443,0.001011,0.013743,0.003905,0.0,0.000226,0.000064,0.000000,0
1997,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0
1998,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.000000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0


## Create Graphs

In [2]:
# permutations for doubly connected edges
from itertools import permutations
import functools
import networkx as nx
import sonnet as snt

from graph_nets import blocks

from graph_nets import graphs
from graph_nets import modules
from graph_nets import utils_np
from graph_nets import utils_tf

In [15]:
# event0 = df.loc[0]
event0 = dummy_df.loc[0]
event0

EMB1_0        0.0
EMB1_1        0.0
EMB1_2        0.0
EMB1_3        0.0
EMB1_4        0.0
             ... 
TileBar2_4    0.0
TileBar2_5    0.0
TileBar2_6    0.0
TileBar2_7    0.0
is_p0         0.0
Name: 0, Length: 937, dtype: float64

In [3]:
def make_fully_connected_edges(nodes):
    """
    returns a list of tuples with (sender_node, reciever_node) for a fully connected graph
    ex: [(1,2), (2,1), (0,1)]
    """
    n_nodes = len(nodes)
    return list(permutations(range(n_nodes), 2))

In [4]:
def make_graph(event):
     
    MIN_VALUE = 0.0
    
    
    col_names = list(event.index)
    nodes = [[cell] for cell in event[col_names][event[col_names] > MIN_VALUE]]
    if len(nodes) < 1:
        return (None, None)
    
    # since the last column is the classifier index
    # remove that in the nodes.
    nodes = np.array(nodes, dtype=np.float32)
    
    
    solution = nodes[-1]
    nodes = nodes[:-1]
    n_nodes = len(nodes)

    
    edge_endpoints = make_fully_connected_edges(nodes)
    senders = np.array([x[0] for x in edge_endpoints])
    receivers = np.array([x[1] for x in edge_endpoints])
    n_edges = len(edge_endpoints)
    edges = np.expand_dims(np.array([0.0]*n_edges, dtype=np.float32), axis=1)

    
    input_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": np.array([n_nodes], dtype=np.float32)
    }
    target_datadict = {
        "n_node": n_nodes,
        "n_edge": n_edges,
        "nodes": nodes,
        "edges": edges,
        "senders": senders,
        "receivers": receivers,
        "globals": solution
    }
    input_graph = utils_tf.data_dicts_to_graphs_tuple([input_datadict])
    target_graph = utils_tf.data_dicts_to_graphs_tuple([target_datadict])
    
    return (input_graph, target_graph)

In [5]:
def print_graphs_tuple(g, data=True):
    for field_name in graphs.ALL_FIELDS:
        per_replica_sample = getattr(g, field_name)
        if per_replica_sample is None:
            print(field_name, "EMPTY")
        else:
            print(field_name, "has shape", per_replica_sample.shape)
            if data and  field_name != "edges":
                print(per_replica_sample)

In [16]:
input_graph, target_graph = make_graph(event0)

## Graph net

In [22]:
# Need the newest dev version of graph_nets (see https://github.com/deepmind/graph_nets/issues/139)
# as of 3/25/2021


# !pip install git+git://github.com/deepmind/graph_nets.git

In [6]:
NUM_LAYERS = 2
def make_mlp_model():
  """Instantiates a new MLP, followed by LayerNorm.

  The parameters of each new MLP are not shared with others generated by
  this function.

  Returns:
    A Sonnet module which contains the MLP and LayerNorm.
  """
  # the activation function choices:
  # swish, relu, relu6, leaky_relu
  return snt.Sequential([
      snt.nets.MLP([128, 64]*NUM_LAYERS,
                    activation=tf.nn.relu,
                    activate_final=True, 
                  #  dropout_rate=DROPOUT_RATE
        ),
      snt.LayerNorm(axis=-1, create_scale=True, create_offset=False)
  ])

In [7]:
class MLPGraphNetwork(snt.Module):
    """GraphIndependent with MLP edge, node, and global models."""
    def __init__(self, name="MLPGraphNetwork"):
        super(MLPGraphNetwork, self).__init__(name=name)
        self._network = modules.GraphNetwork(
            edge_model_fn=make_mlp_model,
            node_model_fn=make_mlp_model,
            global_model_fn=make_mlp_model
            )

    def __call__(self, inputs,
            edge_model_kwargs=None,
            node_model_kwargs=None,
            global_model_kwargs=None):
        return self._network(inputs,
                      edge_model_kwargs=edge_model_kwargs,
                      node_model_kwargs=node_model_kwargs,
                      global_model_kwargs=global_model_kwargs)

In [8]:
LATENT_SIZE = 128

class GlobalClassifierNoEdgeInfo(snt.Module):

    def __init__(self, name="GlobalClassifierNoEdgeInfo"):
        super(GlobalClassifierNoEdgeInfo, self).__init__(name=name)

        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')

        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block'
        )

        self._global_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False,
        )

        self._core = MLPGraphNetwork()
        # Transforms the outputs into appropriate shapes.
        global_output_size = 1
        global_fn =lambda: snt.Sequential([
            snt.nets.MLP([LATENT_SIZE, global_output_size],
                         name='global_output'), tf.sigmoid])

        self._output_transform = modules.GraphIndependent(None, None, global_fn)

    def __call__(self, input_op, num_processing_steps):
        latent = self._global_block(self._edge_block(self._node_encoder_block(input_op)))
        latent0 = latent

        output_ops = []
        for _ in range(num_processing_steps):
            core_input = utils_tf.concat([latent0, latent], axis=1)
            latent = self._core(core_input)
            output_ops.append(self._output_transform(latent))

        return output_ops

In [17]:
model = GlobalClassifierNoEdgeInfo()

In [18]:
output_graphs = model(input_graph, 10)

In [19]:
[x.globals for x in output_graphs]

[<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.36057603]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.40430257]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.43811432]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.42426237]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.43808976]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.41427603]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.41621685]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.41311935]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.41676453]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.41221336]], dtype=float32)>]

In [20]:
# Loss function:

class GlobalLoss:
    def __init__(self, real_global_weight, fake_global_weight):
        self.w_global_real = real_global_weight
        self.w_global_fake = fake_global_weight

    def __call__(self, target_op, output_ops):
        global_weights = target_op.globals * self.w_global_real \
            + (1 - target_op.globals) * self.w_global_fake
        
        print(global_weights)
        
        loss_ops = [
            tf.compat.v1.losses.log_loss(target_op.globals, output_op.globals, weights=global_weights)
            for output_op in output_ops
        ]
        return tf.stack(loss_ops)

In [21]:
loss_function_global = GlobalLoss(real_global_weight = 1.0, fake_global_weight = 1.0)

In [22]:
loss_function_global(target_graph, output_graphs)

tf.Tensor([[1.]], shape=(1, 1), dtype=float32)


<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.45646122, 0.5242964 , 0.5804846 , 0.5570454 , 0.58044255,
       0.54051274, 0.5437024 , 0.5386171 , 0.5446045 , 0.53713506],
      dtype=float32)>

## Model Training

In [23]:
# TODO modify acc function

def compute_accuracy(target, output):
    """Calculate model accuracy.

    Returns the number of correctly predicted links and the number
    of completely solved list sorts (100% correct predictions).

    Args:
    target: A `graphs.GraphsTuple` that contains the target graph.
    output: A `graphs.GraphsTuple` that contains the output graph.

    Returns:
    correct: A `float` fraction of correctly labeled nodes/edges.
    solved: A `float` fraction of graphs that are completely correctly labeled.
    """
    tdds = utils_np.graphs_tuple_to_data_dicts(target)
    odds = utils_np.graphs_tuple_to_data_dicts(output)
    cs = []
    ss = []
    for td, od in zip(tdds, odds):
        num_elements = td["nodes"].shape[0]
        xn = np.argmax(td["nodes"], axis=-1)
        yn = np.argmax(od["nodes"], axis=-1)

        xe = np.reshape(
            np.argmax(
                np.reshape(td["edges"], (num_elements, num_elements, 2)), axis=-1),
            (-1,))
        ye = np.reshape(
            np.argmax(
                np.reshape(od["edges"], (num_elements, num_elements, 2)), axis=-1),
            (-1,))
        c = np.concatenate((xn == yn, xe == ye), axis=0)
        s = np.all(c)
        cs.append(c)
        ss.append(s)
    correct = np.mean(np.concatenate(cs, axis=0))
    solved = np.mean(np.stack(ss))
    return correct, solved

In [24]:
def get_signature(dataset, batch_size):
    """
    Get signature of inputs for the training loop.
    The signature is used by the tf.function
    """

    input_list = []
    target_list = []
    for _, data in dataset.iterrows():
        dd = make_graph(data)
        if dd[0] is not None:
            input_list.append(dd[0])
            target_list.append(dd[1])
            
        if len(input_list) == batch_size:
            break

    inputs = utils_tf.concat(input_list, axis=0)
    targets = utils_tf.concat(target_list, axis=0)
    input_signature = (
      utils_tf.specs_from_graphs_tuple(inputs),
      utils_tf.specs_from_graphs_tuple(targets)
    )

    return input_signature

In [25]:
batch_size = 2
# the signature has to include the batch size

input_signature = get_signature(dummy_df, batch_size)

# Model parameters.
# Number of processing (message-passing) steps.
num_processing_steps_tr = 10
num_processing_steps_ge = 10


learning_rate = 1e-3
optimizer = snt.optimizers.Adam(learning_rate)


# model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)
last_iteration = 0
generalization_iteration = 0

logged_iterations = []
losses_tr = []
corrects_tr = []
solveds_tr = []
losses_ge = []
corrects_ge = []
solveds_ge = []


@functools.partial(tf.function, input_signature=input_signature)
def update_step(inputs_tr, targets_tr):
    print("Tracing update_step")
    with tf.GradientTape() as tape:
        outputs_tr = model(inputs_tr, num_processing_steps_tr)
        loss_ops_tr = loss_function_global(targets_tr, outputs_tr)
        loss_op_tr = tf.math.reduce_sum(loss_ops_tr) / tf.constant(num_processing_steps_tr, dtype=tf.float32)

    gradients = tape.gradient(loss_op_tr, model.trainable_variables)
    optimizer.apply(gradients, model.trainable_variables)
    return outputs_tr, loss_op_tr

In [26]:
from sklearn.model_selection import train_test_split

# train and generalization df
# df_train, df_test = train_test_split(df, test_size = 0.2, random_state = 42)
df_train, df_test = train_test_split(dummy_df, test_size = 0.2, random_state = 42)

In [27]:
%%time
# TODO this is very slow
# make graphs for each event
train_graphs = [make_graph(event) for _, event in df_train.iterrows()]
test_graphs = [make_graph(event) for _, event in df_test.iterrows()]

CPU times: user 22.2 s, sys: 797 ms, total: 23 s
Wall time: 21.6 s


In [28]:
train_graphs = [x for x in train_graphs if x[0] is not None]
test_graphs = [x for x in test_graphs if x[0] is not None]

In [29]:
# save train_graphs and test_graphs objects to file, it takes too long to make
import pickle

def save_object(obj, filename):
    with open(filename, "wb") as output:  # Overwrites any existing file.
        pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL)
"""
save_object(train_graphs, "Temp/train_graphs.pkl")
save_object(train_graphs, "Temp/test_graphs.pkl")
"""

'\nsave_object(train_graphs, "Temp/train_graphs.pkl")\nsave_object(train_graphs, "Temp/test_graphs.pkl")\n'

In [30]:
def loop_dataset(datasets, batch_size):
    if batch_size > 0:
        in_list = []
        target_list = []
        for dataset in datasets:
            inputs_tr, targets_tr = dataset
            if inputs_tr is None:
                continue
            in_list.append(inputs_tr)
            target_list.append(targets_tr)
            if len(in_list) == batch_size:
                inputs_tr = utils_tf.concat(in_list, axis=0)
                targets_tr = utils_tf.concat(target_list, axis=0)
                yield (inputs_tr, targets_tr)
                in_list = []
                target_list = []
    else:
        for dataset in datasets:
            if dataset is None:
                continue
            yield dataset

In [31]:
training_data = loop_dataset(train_graphs, batch_size)

In [32]:
data = next(training_data)

In [33]:
len(data)

2

In [34]:
input_tr, target_tr = data

In [35]:
update_step(input_tr, target_tr)[1].numpy()

Tracing update_step
Tensor("add:0", shape=(2, 1), dtype=float32)




Tracing update_step
Tensor("add:0", shape=(2, 1), dtype=float32)


0.7798856

In [37]:
# How much time between logging and printing the current results.
log_every_seconds = 10

start_time = time.time()
last_log_time = start_time

# Data / training parameters.
num_training_iterations = 20

# code for training loop:
# https://github.com/xju2/root_gnn/blob/tf2/root_gnn/scripts/train_classifier
for epoch in range(2):
    total_loss = 0.
    num_batches = 0
    
    for _ in range(num_training_iterations):
        input_tr, target_tr = next(training_data)
        total_loss += update_step(input_tr, target_tr)[1].numpy()
        num_batches += 1
        
    loss_tr = total_loss/num_batches
    print("Loss value: ", loss_tr)

Loss value:  0.5761519506573677
Loss value:  0.6627717569470406
