## **Edge-Supervised GraphSAGE**

In [None]:
# Remove JAX/jaxlib and pin TensorFlow to the 2.17 series; the rest of the
# notebook runs GraphSAGE code through tf.compat.v1 (graph mode).
# NOTE(review): presumably JAX is removed to avoid dependency conflicts — confirm.
%pip uninstall -y jax jaxlib
%pip install -U "tensorflow==2.17.*"

In [None]:
# Report the installed TensorFlow / Python versions and switch TensorFlow to
# TF1-style graph mode (placeholders and sessions are used below).
import tensorflow as tf, sys
print("TF", tf.__version__, "Python", sys.version)
tf.compat.v1.disable_eager_execution()

In [20]:
# ==== SHIM TF1->TF2 ====
# Re-expose the TF1 names used by the GraphSAGE code below on the TF2 `tf`
# module, so that code can run unchanged in graph mode.
import types, tensorflow as tf
tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Module-level aliases pointing at the tf.compat.v1 (graph-mode) APIs.
tf.app = tf1.app
tf.variable_scope = tf1.variable_scope
tf.get_variable   = tf1.get_variable
tf.summary        = tf1.summary
tf.train          = tf1.train
tf.set_random_seed= tf1.set_random_seed
tf.global_variables_initializer = tf1.global_variables_initializer
tf.reset_default_graph         = tf1.reset_default_graph

tf.random_uniform = tf1.random_uniform
tf.random_shuffle = tf1.random_shuffle
# NOTE: this replaces tf.nn.dropout process-wide with the TF1 function, whose
# second positional argument is keep_prob (TF2's is rate). The layers below
# call tf.nn.dropout(x, 1 - dropout) and rely on the keep_prob meaning.
tf.nn.dropout     = tf1.nn.dropout
tf.placeholder    = tf1.placeholder
tf.Session        = tf1.Session

def _xavier_init():
    """Stand-in for tf.contrib.layers.xavier_initializer (Glorot uniform)."""
    return tf1.keras.initializers.glorot_uniform()

# Minimal replacement for the removed tf.contrib namespace: only the two
# helpers referenced by the Dense layer are provided.
tf.contrib = types.SimpleNamespace(
    layers = types.SimpleNamespace(
        xavier_initializer = _xavier_init,
        l2_regularizer    = tf.keras.regularizers.l2,
    )
)

In [26]:
import os, json, time, numpy as np, pandas as pd, matplotlib.pyplot as plt
from networkx.readwrite import json_graph
import networkx as nx
from datetime import datetime
from zoneinfo import ZoneInfo
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, confusion_matrix, precision_recall_curve, roc_curve, precision_score, recall_score, accuracy_score
import tensorflow as tf

### **GraphSAGE**

#### Inits

In [22]:
import tensorflow as tf
import numpy as np

# DISCLAIMER:
# Parts of this code file are derived from
# https://github.com/tkipf/gcn
# which is released under the same MIT license as GraphSAGE

def uniform(shape, scale=0.05, name=None):
    """Create a float32 variable initialised from U(-scale, scale).

    Args:
        shape: Shape of the variable.
        scale: Half-width of the sampling interval.
        name: Optional variable name.
    """
    low, high = -scale, scale
    values = tf.random_uniform(shape, minval=low, maxval=high, dtype=tf.float32)
    return tf.Variable(values, name=name)


def glorot(shape, name=None):
    """Create a float32 variable with Glorot & Bengio (AISTATS 2010) uniform init.

    Values are drawn from U(-r, r) with r = sqrt(6 / (shape[0] + shape[1])).
    """
    fan_in, fan_out = shape[0], shape[1]
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    values = tf.random_uniform(shape, minval=-limit, maxval=limit, dtype=tf.float32)
    return tf.Variable(values, name=name)


def zeros(shape, name=None):
    """Create a float32 variable of the given shape filled with 0."""
    return tf.Variable(tf.zeros(shape, dtype=tf.float32), name=name)

def ones(shape, name=None):
    """Create a float32 variable of the given shape filled with 1."""
    return tf.Variable(tf.ones(shape, dtype=tf.float32), name=name)

#### Layers

In [23]:
from __future__ import division
from __future__ import print_function

# TF1 (absl) flag registry. Dense reads FLAGS.weight_decay from it.
# NOTE(review): no flags are defined or parsed in the visible notebook —
# confirm weight_decay is registered elsewhere before relying on it.
flags = tf.app.flags
FLAGS = flags.FLAGS

# DISCLAIMER:
# Boilerplate parts of this code file were originally forked from
# https://github.com/tkipf/gcn
# which itself was very inspired by the keras package

# global unique layer ID dictionary for layer name assignment
_LAYER_UIDS = {}

def get_layer_uid(layer_name=''):
    """Return the next 1-based integer ID for ``layer_name``.

    Each distinct name keeps its own counter in ``_LAYER_UIDS``. The first
    call for a name returns 1, and every later call returns one more.
    """
    next_uid = _LAYER_UIDS.get(layer_name, 0) + 1
    _LAYER_UIDS[layer_name] = next_uid
    return next_uid

class Layer(object):
    """Base class for every layer in this notebook (keras-inspired API).

    Properties:
        name: Name scope of the layer; defaults to ``<classname>_<uid>``.
        logging: When True, histogram summaries of inputs/outputs/vars are written.
        vars: Dict of the layer's TF variables, filled in by subclasses.
        sparse_inputs: Marks inputs as sparse; input histograms are skipped then.

    Methods:
        _call(inputs): Builds the layer computation (identity here).
        __call__(inputs): Runs _call() inside the layer's name scope, with
            optional histogram logging.
        _log_vars(): Writes one histogram summary per entry in ``vars``.
    """

    def __init__(self, **kwargs):
        allowed_kwargs = {'name', 'logging', 'model_size'}
        for key in kwargs:
            assert key in allowed_kwargs, 'Invalid keyword argument: ' + key
        name = kwargs.get('name')
        if not name:
            prefix = self.__class__.__name__.lower()
            name = f"{prefix}_{get_layer_uid(prefix)}"
        self.name = name
        self.vars = {}
        self.logging = kwargs.get('logging', False)
        self.sparse_inputs = False

    def _call(self, inputs):
        # Identity; subclasses override this with the real computation.
        return inputs

    def __call__(self, inputs):
        with tf.name_scope(self.name):
            if self.logging and not self.sparse_inputs:
                tf.summary.histogram(self.name + '/inputs', inputs)
            outputs = self._call(inputs)
            if self.logging:
                tf.summary.histogram(self.name + '/outputs', outputs)
        return outputs

    def _log_vars(self):
        for var_name, variable in self.vars.items():
            tf.summary.histogram(self.name + '/vars/' + var_name, variable)


class Dense(Layer):
    """Fully-connected layer computing ``act(dropout(x) @ W + b)``.

    Args:
        input_dim: Width of the last axis of the input.
        output_dim: Number of output units.
        dropout: Drop rate (float or scalar tensor). It is passed to TF1
            ``tf.nn.dropout`` as ``keep_prob = 1 - dropout``.
        act: Activation applied to the affine output.
        placeholders: Dict of placeholders. Only read when ``sparse_inputs``
            is True, and then it must contain ``'num_features_nonzero'``.
        bias: Whether to add a learned bias (initialised to zeros).
        featureless: Stored for API compatibility; not used by ``_call``.
        sparse_inputs: Marks inputs as sparse (disables input histograms).
        weight_decay: L2 coefficient for the weight regulariser. When None,
            ``FLAGS.weight_decay`` is used if that flag is available;
            otherwise 0.0 is used.
        **kwargs: Forwarded to ``Layer`` (``name``, ``logging``, ``model_size``).
    """
    def __init__(self, input_dim, output_dim, dropout=0.,
                 act=tf.nn.relu, placeholders=None, bias=True, featureless=False,
                 sparse_inputs=False, weight_decay=None, **kwargs):
        super(Dense, self).__init__(**kwargs)

        self.dropout = dropout

        self.act = act
        self.featureless = featureless
        self.bias = bias
        self.input_dim = input_dim
        self.output_dim = output_dim

        # helper variable for sparse dropout
        self.sparse_inputs = sparse_inputs
        if sparse_inputs:
            self.num_features_nonzero = placeholders['num_features_nonzero']

        decay = self._resolve_weight_decay(weight_decay)
        with tf.variable_scope(self.name + '_vars'):
            self.vars['weights'] = tf.get_variable('weights', shape=(input_dim, output_dim),
                                         dtype=tf.float32,
                                         initializer=tf.contrib.layers.xavier_initializer(),
                                         regularizer=tf.contrib.layers.l2_regularizer(decay))
            if self.bias:
                self.vars['bias'] = zeros([output_dim], name='bias')

        if self.logging:
            self._log_vars()

    @staticmethod
    def _resolve_weight_decay(weight_decay):
        """Return ``weight_decay``, else ``FLAGS.weight_decay``, else 0.0."""
        if weight_decay is not None:
            return weight_decay
        try:
            return FLAGS.weight_decay
        except Exception:
            # Deliberately broad: absl raises AttributeError for an undefined
            # flag and UnparsedFlagAccessError when flags were never parsed
            # (the normal situation inside a notebook). Both mean "no value".
            return 0.0

    def _call(self, inputs):
        # TF1 dropout (see shim): the second argument is keep_prob.
        x = tf.nn.dropout(inputs, 1 - self.dropout)
        output = tf.matmul(x, self.vars['weights'])
        if self.bias:
            output += self.vars['bias']
        return self.act(output)

#### Aggregators

In [24]:
class MeanAggregator(Layer):
    """
    Aggregates via mean followed by matmul and non-linearity.

    ``output = act(self_vecs @ W_self  (+ or concat)  mean(neigh_vecs) @ W_neigh  [+ b])``

    Args:
        input_dim: Feature width of ``self_vecs``.
        output_dim: Width of each projection. The layer returns
            ``2 * output_dim`` columns when ``concat`` is True.
        neigh_input_dim: Feature width of ``neigh_vecs`` (defaults to input_dim).
        dropout: Drop rate for both inputs (TF1 ``keep_prob = 1 - dropout``);
            may be a scalar tensor.
        bias: Add a learned bias sized to the output width.
        act: Output activation.
        name: Optional suffix appended to the variable-scope name.
        concat: Concatenate (True) or sum (False) the two projections.
        **kwargs: Forwarded to ``Layer``.
    """

    def __init__(self, input_dim, output_dim, neigh_input_dim=None,
            dropout=0., bias=False, act=tf.nn.relu,
            name=None, concat=False, **kwargs):
        super(MeanAggregator, self).__init__(**kwargs)

        self.dropout = dropout
        self.bias = bias
        self.act = act
        self.concat = concat

        if neigh_input_dim is None:
            neigh_input_dim = input_dim
        # Assigned before the variables are built (the bias used to read
        # self.output_dim before it existed and raised AttributeError).
        self.input_dim = input_dim
        self.output_dim = output_dim

        name = '/' + name if name is not None else ''

        with tf.variable_scope(self.name + name + '_vars'):
            self.vars['neigh_weights'] = glorot([neigh_input_dim, output_dim],
                                                name='neigh_weights')
            self.vars['self_weights'] = glorot([input_dim, output_dim],
                                               name='self_weights')
            if self.bias:
                # The concatenated output is twice as wide as each projection.
                bias_dim = 2 * output_dim if concat else output_dim
                self.vars['bias'] = zeros([bias_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Aggregate one hop.

        Args:
            inputs: Tuple ``(self_vecs, neigh_vecs)`` with ``self_vecs`` of
                shape ``[nodes, input_dim]`` and ``neigh_vecs`` of shape
                ``[nodes, num_neighbors, neigh_input_dim]`` (averaged over axis 1).
        """
        self_vecs, neigh_vecs = inputs

        neigh_vecs = tf.nn.dropout(neigh_vecs, 1-self.dropout)
        self_vecs = tf.nn.dropout(self_vecs, 1-self.dropout)
        neigh_means = tf.reduce_mean(neigh_vecs, axis=1)

        # [nodes] x [out_dim]
        from_neighs = tf.matmul(neigh_means, self.vars['neigh_weights'])
        from_self = tf.matmul(self_vecs, self.vars["self_weights"])

        if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

        if self.bias:
            output += self.vars['bias']

        return self.act(output)

class GCNAggregator(Layer):
    """
    Aggregates via mean followed by matmul and non-linearity.
    Same matmul parameters are used self vector and neighbor vectors.

    The self vector is appended to the neighbour set, the mean of that set is
    projected with a single weight matrix, and the activation is applied.
    ``concat`` is stored but not used: the output is always ``output_dim`` wide.

    Args:
        input_dim: Feature width of ``self_vecs``.
        output_dim: Output width.
        neigh_input_dim: Feature width of ``neigh_vecs`` (defaults to input_dim).
        dropout: Drop rate for both inputs (TF1 ``keep_prob = 1 - dropout``).
        bias: Add a learned bias of width ``output_dim``.
        act: Output activation.
        name: Optional suffix appended to the variable-scope name.
        concat: Accepted for interface compatibility; ignored.
        **kwargs: Forwarded to ``Layer``.
    """

    def __init__(self, input_dim, output_dim, neigh_input_dim=None,
            dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
        super(GCNAggregator, self).__init__(**kwargs)

        self.dropout = dropout
        self.bias = bias
        self.act = act
        self.concat = concat

        if neigh_input_dim is None:
            neigh_input_dim = input_dim
        # Assigned before the variables are built (the bias used to read
        # self.output_dim before it existed and raised AttributeError).
        self.input_dim = input_dim
        self.output_dim = output_dim

        name = '/' + name if name is not None else ''

        with tf.variable_scope(self.name + name + '_vars'):
            self.vars['weights'] = glorot([neigh_input_dim, output_dim],
                                          name='neigh_weights')
            if self.bias:
                self.vars['bias'] = zeros([output_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Aggregate one hop from ``(self_vecs, neigh_vecs)``.

        ``self_vecs`` is expanded to ``[nodes, 1, dim]`` and concatenated
        with ``neigh_vecs`` along axis 1 before averaging.
        """
        self_vecs, neigh_vecs = inputs

        neigh_vecs = tf.nn.dropout(neigh_vecs, 1-self.dropout)
        self_vecs = tf.nn.dropout(self_vecs, 1-self.dropout)
        means = tf.reduce_mean(tf.concat([neigh_vecs,
            tf.expand_dims(self_vecs, axis=1)], axis=1), axis=1)

        # [nodes] x [out_dim]
        output = tf.matmul(means, self.vars['weights'])

        if self.bias:
            output += self.vars['bias']

        return self.act(output)


class MaxPoolingAggregator(Layer):
    """ Aggregates via max-pooling over MLP functions.

    Each neighbour vector goes through one ReLU Dense layer of width
    ``hidden_dim`` (512 for model_size "small", 1024 for "big"). The results
    are max-pooled over neighbours, projected, and combined (sum or concat)
    with the projected self vector.

    Args:
        input_dim: Feature width of ``self_vecs``.
        output_dim: Width of each projection (output is ``2 * output_dim``
            wide when ``concat`` is True).
        model_size: "small" or "big"; selects ``hidden_dim``.
        neigh_input_dim: Feature width of ``neigh_vecs`` (defaults to input_dim).
        dropout: Drop rate used by the neighbour MLP.
        bias: Add a learned bias sized to the output width.
        act: Output activation.
        name: Optional suffix appended to the variable-scope name.
        concat: Concatenate (True) or sum (False) the two projections.
        **kwargs: Forwarded to ``Layer``.

    Raises:
        ValueError: If ``model_size`` is not "small" or "big".
    """
    def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
            dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
        super(MaxPoolingAggregator, self).__init__(**kwargs)

        self.dropout = dropout
        self.bias = bias
        self.act = act
        self.concat = concat

        if neigh_input_dim is None:
            neigh_input_dim = input_dim
        # Assigned before the variables are built (the bias used to read
        # self.output_dim before it existed and raised AttributeError).
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.neigh_input_dim = neigh_input_dim

        name = '/' + name if name is not None else ''

        hidden_sizes = {"small": 512, "big": 1024}
        if model_size not in hidden_sizes:
            raise ValueError("model_size must be 'small' or 'big', got %r" % (model_size,))
        hidden_dim = self.hidden_dim = hidden_sizes[model_size]

        self.mlp_layers = []
        self.mlp_layers.append(Dense(input_dim=neigh_input_dim,
                                     output_dim=hidden_dim,
                                     act=tf.nn.relu,
                                     dropout=dropout,
                                     sparse_inputs=False,
                                     logging=self.logging))

        with tf.variable_scope(self.name + name + '_vars'):
            self.vars['neigh_weights'] = glorot([hidden_dim, output_dim],
                                                name='neigh_weights')
            self.vars['self_weights'] = glorot([input_dim, output_dim],
                                               name='self_weights')
            if self.bias:
                # The concatenated output is twice as wide as each projection.
                bias_dim = 2 * output_dim if concat else output_dim
                self.vars['bias'] = zeros([bias_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Aggregate one hop from ``(self_vecs, neigh_vecs)``.

        ``neigh_vecs`` is reshaped to ``[nodes * num_neighbors, neigh_input_dim]``
        for the MLP, so it must be rank 3 ``[nodes, num_neighbors, neigh_input_dim]``.
        """
        self_vecs, neigh_vecs = inputs
        neigh_h = neigh_vecs

        dims = tf.shape(neigh_h)
        batch_size = dims[0]
        num_neighbors = dims[1]
        # [nodes * sampled neighbors] x [hidden_dim]
        h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim))

        for l in self.mlp_layers:
            h_reshaped = l(h_reshaped)
        neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim))
        neigh_h = tf.reduce_max(neigh_h, axis=1)

        from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights'])
        from_self = tf.matmul(self_vecs, self.vars["self_weights"])

        if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

        if self.bias:
            output += self.vars['bias']

        return self.act(output)

class MeanPoolingAggregator(Layer):
    """ Aggregates via mean-pooling over MLP functions.

    Each neighbour vector goes through one ReLU Dense layer of width
    ``hidden_dim`` (512 for model_size "small", 1024 for "big"). The results
    are averaged over neighbours, projected, and combined (sum or concat)
    with the projected self vector.

    Args:
        input_dim: Feature width of ``self_vecs``.
        output_dim: Width of each projection (output is ``2 * output_dim``
            wide when ``concat`` is True).
        model_size: "small" or "big"; selects ``hidden_dim``.
        neigh_input_dim: Feature width of ``neigh_vecs`` (defaults to input_dim).
        dropout: Drop rate used by the neighbour MLP.
        bias: Add a learned bias sized to the output width.
        act: Output activation.
        name: Optional suffix appended to the variable-scope name.
        concat: Concatenate (True) or sum (False) the two projections.
        **kwargs: Forwarded to ``Layer``.

    Raises:
        ValueError: If ``model_size`` is not "small" or "big".
    """
    def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
            dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
        super(MeanPoolingAggregator, self).__init__(**kwargs)

        self.dropout = dropout
        self.bias = bias
        self.act = act
        self.concat = concat

        if neigh_input_dim is None:
            neigh_input_dim = input_dim
        # Assigned before the variables are built (the bias used to read
        # self.output_dim before it existed and raised AttributeError).
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.neigh_input_dim = neigh_input_dim

        name = '/' + name if name is not None else ''

        hidden_sizes = {"small": 512, "big": 1024}
        if model_size not in hidden_sizes:
            raise ValueError("model_size must be 'small' or 'big', got %r" % (model_size,))
        hidden_dim = self.hidden_dim = hidden_sizes[model_size]

        self.mlp_layers = []
        self.mlp_layers.append(Dense(input_dim=neigh_input_dim,
                                     output_dim=hidden_dim,
                                     act=tf.nn.relu,
                                     dropout=dropout,
                                     sparse_inputs=False,
                                     logging=self.logging))

        with tf.variable_scope(self.name + name + '_vars'):
            self.vars['neigh_weights'] = glorot([hidden_dim, output_dim],
                                                name='neigh_weights')
            self.vars['self_weights'] = glorot([input_dim, output_dim],
                                               name='self_weights')
            if self.bias:
                # The concatenated output is twice as wide as each projection.
                bias_dim = 2 * output_dim if concat else output_dim
                self.vars['bias'] = zeros([bias_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Aggregate one hop from ``(self_vecs, neigh_vecs)``.

        ``neigh_vecs`` is reshaped to ``[nodes * num_neighbors, neigh_input_dim]``
        for the MLP, so it must be rank 3 ``[nodes, num_neighbors, neigh_input_dim]``.
        """
        self_vecs, neigh_vecs = inputs
        neigh_h = neigh_vecs

        dims = tf.shape(neigh_h)
        batch_size = dims[0]
        num_neighbors = dims[1]
        # [nodes * sampled neighbors] x [hidden_dim]
        h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim))

        for l in self.mlp_layers:
            h_reshaped = l(h_reshaped)
        neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim))
        neigh_h = tf.reduce_mean(neigh_h, axis=1)

        from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights'])
        from_self = tf.matmul(self_vecs, self.vars["self_weights"])

        if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

        if self.bias:
            output += self.vars['bias']

        return self.act(output)


class TwoMaxLayerPoolingAggregator(Layer):
    """ Aggregates via pooling over two MLP functions.

    Each neighbour vector goes through two ReLU Dense layers (512 -> 256 for
    model_size "small", 1024 -> 512 for "big"). The results are max-pooled
    over neighbours, projected, and combined (sum or concat) with the
    projected self vector.

    Args:
        input_dim: Feature width of ``self_vecs``.
        output_dim: Width of each projection (output is ``2 * output_dim``
            wide when ``concat`` is True).
        model_size: "small" or "big"; selects the two hidden widths.
        neigh_input_dim: Feature width of ``neigh_vecs`` (defaults to input_dim).
        dropout: Drop rate used by the neighbour MLP.
        bias: Add a learned bias sized to the output width.
        act: Output activation.
        name: Optional suffix appended to the variable-scope name.
        concat: Concatenate (True) or sum (False) the two projections.
        **kwargs: Forwarded to ``Layer``.

    Raises:
        ValueError: If ``model_size`` is not "small" or "big".
    """
    def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
            dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
        super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs)

        self.dropout = dropout
        self.bias = bias
        self.act = act
        self.concat = concat

        if neigh_input_dim is None:
            neigh_input_dim = input_dim
        # Assigned before the variables are built (the bias used to read
        # self.output_dim before it existed and raised AttributeError).
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.neigh_input_dim = neigh_input_dim

        name = '/' + name if name is not None else ''

        hidden_sizes = {"small": (512, 256), "big": (1024, 512)}
        if model_size not in hidden_sizes:
            raise ValueError("model_size must be 'small' or 'big', got %r" % (model_size,))
        hidden_dim_1, hidden_dim_2 = hidden_sizes[model_size]
        self.hidden_dim_1, self.hidden_dim_2 = hidden_dim_1, hidden_dim_2

        self.mlp_layers = []
        self.mlp_layers.append(Dense(input_dim=neigh_input_dim,
                                     output_dim=hidden_dim_1,
                                     act=tf.nn.relu,
                                     dropout=dropout,
                                     sparse_inputs=False,
                                     logging=self.logging))
        self.mlp_layers.append(Dense(input_dim=hidden_dim_1,
                                     output_dim=hidden_dim_2,
                                     act=tf.nn.relu,
                                     dropout=dropout,
                                     sparse_inputs=False,
                                     logging=self.logging))

        with tf.variable_scope(self.name + name + '_vars'):
            self.vars['neigh_weights'] = glorot([hidden_dim_2, output_dim],
                                                name='neigh_weights')
            self.vars['self_weights'] = glorot([input_dim, output_dim],
                                               name='self_weights')
            if self.bias:
                # The concatenated output is twice as wide as each projection.
                bias_dim = 2 * output_dim if concat else output_dim
                self.vars['bias'] = zeros([bias_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        """Aggregate one hop from ``(self_vecs, neigh_vecs)``.

        ``neigh_vecs`` is reshaped to ``[nodes * num_neighbors, neigh_input_dim]``
        for the MLP, so it must be rank 3 ``[nodes, num_neighbors, neigh_input_dim]``.
        """
        self_vecs, neigh_vecs = inputs
        neigh_h = neigh_vecs

        dims = tf.shape(neigh_h)
        batch_size = dims[0]
        num_neighbors = dims[1]
        # [nodes * sampled neighbors] x [hidden_dim]
        h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim))

        for l in self.mlp_layers:
            h_reshaped = l(h_reshaped)
        neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim_2))
        neigh_h = tf.reduce_max(neigh_h, axis=1)

        from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights'])
        from_self = tf.matmul(self_vecs, self.vars["self_weights"])

        if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

        if self.bias:
            output += self.vars['bias']

        return self.act(output)

class SeqAggregator(Layer):
    """ Aggregates via a standard LSTM.

    The neighbour vectors are run through an LSTM (hidden size 128 for
    model_size "small", 256 for "big"). The output at the last non-empty
    neighbour position is projected and combined (sum or concat) with the
    projected self vector.

    Args:
        input_dim: Feature width of ``self_vecs``.
        output_dim: Width of each projection (output is ``2 * output_dim``
            wide when ``concat`` is True).
        model_size: "small" or "big"; selects the LSTM hidden size.
        neigh_input_dim: Feature width of ``neigh_vecs`` (defaults to input_dim).
        dropout: Stored; not applied by ``_call``.
        bias: Add a learned bias sized to the output width.
        act: Output activation.
        name: Optional suffix appended to the variable-scope name.
        concat: Concatenate (True) or sum (False) the two projections.
        **kwargs: Forwarded to ``Layer``.

    Raises:
        ValueError: If ``model_size`` is not "small" or "big".
    """
    def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
            dropout=0., bias=False, act=tf.nn.relu, name=None,  concat=False, **kwargs):
        super(SeqAggregator, self).__init__(**kwargs)

        self.dropout = dropout
        self.bias = bias
        self.act = act
        self.concat = concat

        if neigh_input_dim is None:
            neigh_input_dim = input_dim
        # Assigned before the variables are built (the bias used to read
        # self.output_dim before it existed and raised AttributeError).
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.neigh_input_dim = neigh_input_dim

        name = '/' + name if name is not None else ''

        hidden_sizes = {"small": 128, "big": 256}
        if model_size not in hidden_sizes:
            raise ValueError("model_size must be 'small' or 'big', got %r" % (model_size,))
        hidden_dim = self.hidden_dim = hidden_sizes[model_size]

        with tf.variable_scope(self.name + name + '_vars'):
            self.vars['neigh_weights'] = glorot([hidden_dim, output_dim],
                                                name='neigh_weights')
            self.vars['self_weights'] = glorot([input_dim, output_dim],
                                               name='self_weights')
            if self.bias:
                # The concatenated output is twice as wide as each projection.
                bias_dim = 2 * output_dim if concat else output_dim
                self.vars['bias'] = zeros([bias_dim], name='bias')

        if self.logging:
            self._log_vars()

        # Prefer the TF1 cell + dynamic_rnn; fall back to a Keras LSTMCell
        # when the compat RNN cell is unavailable in this TF build.
        try:
            self.cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(self.hidden_dim)
            self._use_dynamic_rnn = True
        except Exception:
            self.cell = tf.keras.layers.LSTMCell(self.hidden_dim)
            self._use_dynamic_rnn = False
        # Keras RNN wrapper, built lazily once and reused by every call.
        self._rnn_layer = None

    def _call(self, inputs):
        """Aggregate one hop from ``(self_vecs, neigh_vecs)``.

        ``neigh_vecs`` is reduced over axis 2 and indexed per row, so it is
        rank 3 ``[nodes, num_neighbors, feature_dim]``.
        """
        self_vecs, neigh_vecs = inputs

        # A neighbour slot counts as "used" when any of its features is
        # non-zero. The number of used slots per row is the sequence length,
        # clamped to >= 1 so the gather below never reads position -1.
        used = tf.sign(tf.reduce_max(tf.abs(neigh_vecs), axis=2))
        lengths = tf.cast(tf.maximum(tf.reduce_sum(used, axis=1), 1.0), tf.int32)
        batch_size = tf.shape(neigh_vecs)[0]

        if getattr(self, "_use_dynamic_rnn", False):
            initial_state = self.cell.zero_state(batch_size, tf.float32)
            with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
                rnn_outputs, _ = tf.compat.v1.nn.dynamic_rnn(
                    self.cell,
                    neigh_vecs,
                    initial_state=initial_state,
                    dtype=tf.float32,
                    time_major=False,
                    sequence_length=lengths
                )
        else:
            # Build the Keras RNN once: creating a new layer per call would
            # give each call (each hop, source and destination towers) its
            # own unshared LSTM weights.
            if getattr(self, "_rnn_layer", None) is None:
                self._rnn_layer = tf.keras.layers.RNN(
                    self.cell, return_sequences=True, return_state=False, name=self.name
                )
            mask_bool = tf.cast(used, tf.bool)
            rnn_outputs = self._rnn_layer(neigh_vecs, mask=mask_bool, training=False)

        # Gather the output at position lengths-1 of every row from the
        # flattened [batch * max_len, out_size] view.
        max_len = tf.shape(rnn_outputs)[1]
        out_size = tf.shape(rnn_outputs)[2]
        idx = tf.range(batch_size) * max_len + (lengths - 1)
        flat = tf.reshape(rnn_outputs, [-1, out_size])
        neigh_h = tf.gather(flat, idx)

        from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights'])
        from_self = tf.matmul(self_vecs, self.vars['self_weights'])

        if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

        if self.bias:
            output += self.vars['bias']

        return self.act(output)

#### Neighbor Sampler

In [25]:
from __future__ import division
from __future__ import print_function

# Re-bind the TF1 (absl) flag registry for this cell.
flags = tf.app.flags
FLAGS = flags.FLAGS

# NOTE: the string below follows other statements, so it is a no-op
# expression rather than a docstring; it only labels this cell.
"""
Classes that are used to sample node neighborhoods
"""

class UniformNeighborSampler(Layer):
    """
    Uniformly samples neighbors.
    Assumes that adj lists are padded with random re-sampling

    ``adj_info`` is a table with one row of neighbour ids per node. Calling
    the layer with ``(ids, num_samples)`` looks up the rows for ``ids``,
    permutes their columns with one random permutation shared by the whole
    batch, and returns the first ``num_samples`` columns.
    """
    def __init__(self, adj_info, **kwargs):
        super(UniformNeighborSampler, self).__init__(**kwargs)
        self.adj_info = adj_info

    def _call(self, inputs):
        ids, num_samples = inputs
        rows = tf.nn.embedding_lookup(self.adj_info, ids)
        # random_shuffle permutes axis 0, so shuffle the transposed table to
        # permute columns, then transpose back.
        shuffled_columns = tf.random_shuffle(tf.transpose(rows))
        rows = tf.transpose(shuffled_columns)
        return tf.slice(rows, [0, 0], [-1, num_samples])

### **Edge-Supervised GraphSAGE**

In [9]:
# Graph mode plus the initializers used for the edge-classifier head variables.
tf1 = tf.compat.v1; tf1.disable_eager_execution()
xavier_init = tf1.keras.initializers.glorot_uniform()
zeros_init  = tf1.zeros_initializer()

In [None]:
# Mount Google Drive (Colab only).
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# Root folder for processed data and results. An empty string makes every
# derived path relative to the current working directory.
# NOTE(review): presumably meant to point inside /content/gdrive — confirm.
DATA_DIR = ''

In [None]:
# Processed artifacts live under <DATA_DIR>/processed. OUT_PREFIX is the
# filename prefix shared by the -G.json / -id_map.json / -feats.npy /
# -meta.json files and the per-split edge arrays.
PROC_DIR   = os.path.join(DATA_DIR, 'processed')
OUT_PREFIX = os.path.join(PROC_DIR, 'graphsage_edges')

In [None]:
AGGREGATOR = "mean"     # "mean","gcn","maxpool","meanpool","seq"
DIM_1, DIM_2 = 128, 128          # output width of aggregation layers 1 and 2
SAMPLES_1, SAMPLES_2 = 25, 10    # neighbour sample sizes of the two layers
BATCH_SIZE = 1024                # edges per mini-batch
DROPOUT = 0.1                    # drop rate fed to the model during training
LR = 1e-3                        # Adam learning rate
EPOCHS = 2
MAX_DEGREE = 128                 # width of the padded adjacency table
CP_EMB_DIM = 8                   # width of the cp-bucket embedding
SEED = 123
# Seeds NumPy (neighbour padding, batch shuffling) and the TF graph.
np.random.seed(SEED); tf1.set_random_seed(SEED)

In [None]:
def load_graph_base(prefix, max_degree=None):
    """Load the graph, node features, id map and a fixed-width adjacency table.

    Reads ``<prefix>-G.json`` (networkx node-link JSON), ``<prefix>-id_map.json``
    and ``<prefix>-feats.npy``.

    Args:
        prefix: Path prefix shared by the input files.
        max_degree: Width of the adjacency table; defaults to ``MAX_DEGREE``.

    Returns:
        G: The networkx graph.
        feats: Node feature matrix with one all-zero row appended at index
            ``n = len(id_map)``; padding neighbours point at this row.
        id_map: Dict mapping ``str(node)`` to its row index.
        adj: int32 array of shape ``(n + 1, max_degree)``. Row ``u`` holds
            neighbour indices of node ``u``: sub-sampled without replacement
            when the degree exceeds ``max_degree``, re-sampled with
            replacement when it is smaller. Rows of isolated nodes and the
            extra row ``n`` are filled with the padding index ``n``.
        deg: int32 array of true (un-truncated) degrees, shape ``(n,)``.
    """
    if max_degree is None:
        max_degree = MAX_DEGREE
    with open(prefix + "-G.json") as fh:
        G = json_graph.node_link_graph(json.load(fh))
    with open(prefix + "-id_map.json") as fh:
        id_map = {str(k): int(v) for k, v in json.load(fh).items()}
    feats = np.load(prefix + "-feats.npy")
    feats = np.vstack([feats, np.zeros((feats.shape[1],), dtype=feats.dtype)])
    n = len(id_map)
    # Pad with n: it is the index of the zero row appended to feats and the
    # last row of adj, so padded lookups stay in range in both tables.
    # (n + 1 was out of range for both and breaks the embedding lookups.)
    adj = np.full((n + 1, max_degree), n, dtype=np.int32)
    deg = np.zeros((n,), dtype=np.int32)
    for node in G.nodes():
        u = id_map[str(node)]
        neigh = [id_map[str(v)] for v in G.neighbors(node)]
        deg[u] = len(neigh)
        if not neigh:
            continue
        if len(neigh) > max_degree:
            neigh = np.random.choice(neigh, max_degree, replace=False)
        elif len(neigh) < max_degree:
            neigh = np.random.choice(neigh, max_degree, replace=True)
        adj[u, :] = np.asarray(neigh, dtype=np.int32)
    return G, feats, id_map, adj, deg

def load_edge_split(prefix, split):
    """Load the saved edge arrays of one split (e.g. "train", "valid", "test").

    Reads ``<prefix>-<split>_{edges,edge_feats,edge_labels,cp_bucket,txids}.npy``
    and returns ``(edges, efeat, labels, cpb, txids)``. edges, labels and cpb
    are cast to int32, efeat to float32, and txids keeps its stored dtype.
    """
    def _load(suffix):
        return np.load(f"{prefix}-{split}_{suffix}.npy")

    edges = _load("edges").astype(np.int32)
    efeat = _load("edge_feats").astype(np.float32)
    labels = _load("edge_labels").astype(np.int32)
    cpb = _load("cp_bucket").astype(np.int32)
    txids = _load("txids")
    return edges, efeat, labels, cpb, txids

class EdgeLoader:
    """Holds the train/valid/test edge splits and yields mini-batches.

    Each split attribute (``train``, ``valid``, ``test``) is the tuple
    ``(edges, efeat, labels, cpb, txids)`` returned by ``load_edge_split``.
    """
    def __init__(self, prefix):
        self.train = load_edge_split(prefix, "train")
        self.valid = load_edge_split(prefix, "valid")
        self.test = load_edge_split(prefix, "test")
        # Edge-feature width, taken from the training split.
        self.F = self.train[1].shape[1]
        y = self.train[2]
        # Fraction of training labels equal to 1 (0.0 for an empty split).
        self.pos_rate = (y == 1).mean() if len(y) else 0.0

    def get_txids(self, split):
        """Return the transaction-id array stored for ``split``."""
        return getattr(self, split)[4]

    def iter_batches(self, batch_size, split="train", shuffle=True):
        """Yield ``(src, dst, efeat, labels, cpb)`` batches of up to ``batch_size`` edges.

        When ``shuffle`` is True the order is permuted with the global NumPy
        RNG. The final batch may be smaller than ``batch_size``.
        """
        edges, efeat, labels, cpb, _ = getattr(self, split)
        order = np.arange(len(labels))
        if shuffle:
            np.random.shuffle(order)
        for start in range(0, len(order), batch_size):
            sel = order[start:start + batch_size]
            yield edges[sel, 0], edges[sel, 1], efeat[sel], labels[sel], cpb[sel]

class EdgeSupGraphSAGE:
    """Edge classifier on top of a two-layer GraphSAGE encoder (TF1 graph mode).

    Both endpoints of an edge are encoded by the same sampled GraphSAGE tower
    (the aggregators built for the source side are reused for the
    destination side). The two L2-normalised node embeddings are concatenated
    with the raw edge features and a learned embedding of the edge's
    cp bucket. The result goes through a 128-unit ReLU layer and a linear
    layer that produces one logit per edge. The loss is
    ``weighted_cross_entropy_with_logits`` with a fed ``pos_weight``,
    minimised with Adam.

    Args:
        feats: Node feature matrix indexed by node id.
        adj: Fixed-width neighbour table used by the uniform sampler.
        deg: Node degrees (kept as a constant, not used by the graph).
        aggregator: One of "mean", "gcn", "maxpool", "meanpool", "seq".
        dim_1, dim_2: Output widths of the two aggregation layers.
        samples_1, samples_2: Neighbour sample sizes of the two layers.
        dropout: Stored as ``dropout_rate``. The rate actually applied is
            the value fed to ``dropout_ph`` (see ``train_epoch``).
        lr: Adam learning rate.
        cp_n_buckets: Number of rows in the cp-bucket embedding table.
        cp_emb_dim: Width of the cp-bucket embedding.
        concat: Concatenate self/neighbour projections in the aggregators.
            Forced to False for "gcn", whose aggregator never concatenates.
        edge_feat_dim: Width of the edge-feature vectors fed to
            ``edge_feats`` (must match the data, e.g. ``EdgeLoader.F``).
    """
    def __init__(self, feats, adj, deg, aggregator="mean",
                 dim_1=128, dim_2=128, samples_1=25, samples_2=10,
                 dropout=0.0, lr=1e-3, cp_n_buckets=4096, cp_emb_dim=8, concat=True,
                 edge_feat_dim=25):
        from collections import namedtuple

        self.feats_np, self.adj_np, self.deg_np = feats, adj, deg
        # GCNAggregator ignores `concat` and always returns dim-wide outputs,
        # so doubling the next layer's input width would mismatch shapes.
        if aggregator == "gcn":
            concat = False
        self.dropout_rate, self.lr, self.concat = dropout, lr, concat

        self.batch_size_ph = tf1.placeholder(tf.int32, shape=(), name="batch_size")
        self.batch_u = tf1.placeholder(tf.int32, shape=[None], name="batch_src")
        self.batch_v = tf1.placeholder(tf.int32, shape=[None], name="batch_dst")
        # The width must be static: it fixes the input size of edge_fc1.
        self.edge_feats = tf1.placeholder(tf.float32, shape=[None, int(edge_feat_dim)], name="edge_feats")
        self.cp_idx     = tf1.placeholder(tf.int32,   shape=[None], name="cp_bucket")
        self.labels     = tf1.placeholder(tf.float32, shape=[None], name="labels")
        self.pos_weight = tf1.placeholder(tf.float32, shape=(), name="pos_weight")
        self.dropout_ph = tf1.placeholder_with_default(0.0, shape=(), name="dropout")

        self.adj   = tf1.Variable(self.adj_np, trainable=False, dtype=tf.int32, name="adj_info")
        self.feats = tf1.Variable(tf.constant(self.feats_np, dtype=tf.float32), trainable=False, name="node_feats")
        self.degrees = tf.constant(self.deg_np, dtype=tf.int32)

        sampler = UniformNeighborSampler(self.adj)
        agg_cls = {"mean": MeanAggregator, "gcn": GCNAggregator,
                   "maxpool": MaxPoolingAggregator, "meanpool": MeanPoolingAggregator,
                   "seq": SeqAggregator}[aggregator]

        # Local equivalent of graphsage.models.SAGEInfo (same fields). The
        # previous __import__(...).models lookup resolved the wrong attribute
        # and also required the graphsage package to be installed.
        SAGEInfo = namedtuple("SAGEInfo", ["layer_name", "neigh_sampler", "num_samples", "output_dim"])
        layer_infos = [SAGEInfo("node", sampler, samples_1, dim_1),
                       SAGEInfo("node", sampler, samples_2, dim_2)]
        num_samples = [samples_1, samples_2]
        dims = [self.feats_np.shape[1], dim_1, dim_2]

        def sample(inputs, layer_infos, batch_size):
            """Return flat node-id tensors per hop and their support sizes."""
            samples = [inputs]; support_size = 1; support_sizes = [support_size]
            for k in range(len(layer_infos)):
                t = len(layer_infos) - k - 1
                support_size *= layer_infos[t].num_samples
                node = layer_infos[t].neigh_sampler((samples[k], layer_infos[t].num_samples))
                samples.append(tf.reshape(node, [support_size * batch_size,]))
                support_sizes.append(support_size)
            return samples, support_sizes

        def aggregate(samples, input_features, dims, num_samples, support_sizes, batch_size,
                      aggregators=None, concat=True):
            """Run the aggregation layers; build them on the first call only."""
            hidden = [tf.nn.embedding_lookup(input_features, node_samples) for node_samples in samples]
            new_agg = aggregators is None
            if new_agg:
                aggregators = []
            for layer in range(len(num_samples)):
                # With concat, layers after the first receive 2x-wide inputs.
                dim_mult = 2 if concat and (layer != 0) else 1
                if new_agg:
                    # Final layer is linear, earlier layers use ReLU.
                    act = (lambda x: x) if (layer == len(num_samples) - 1) else tf.nn.relu
                    aggregator = agg_cls(dim_mult * dims[layer], dims[layer + 1],
                                         dropout=self.dropout_ph, act=act, concat=concat)
                    aggregators.append(aggregator)
                else:
                    aggregator = aggregators[layer]
                next_hidden = []
                for hop in range(len(num_samples) - layer):
                    neigh_dims = [batch_size * support_sizes[hop],
                                  num_samples[len(num_samples) - hop - 1],
                                  dim_mult * dims[layer]]
                    h = aggregator((hidden[hop], tf.reshape(hidden[hop + 1], neigh_dims)))
                    next_hidden.append(h)
                hidden = next_hidden
            return hidden[0], aggregators

        su, supp_u = sample(self.batch_u, layer_infos, self.batch_size_ph)
        sv, supp_v = sample(self.batch_v, layer_infos, self.batch_size_ph)
        h_u, aggs = aggregate(su, self.feats, dims, num_samples, supp_u, self.batch_size_ph, concat=self.concat)
        h_v, _    = aggregate(sv, self.feats, dims, num_samples, supp_v, self.batch_size_ph, aggregators=aggs, concat=self.concat)

        h_u = tf.nn.l2_normalize(h_u, axis=1)
        h_v = tf.nn.l2_normalize(h_v, axis=1)

        self.cp_emb = tf1.get_variable("cp_emb", shape=[int(cp_n_buckets), int(cp_emb_dim)],
                                       initializer=xavier_init)
        emb_cp = tf.nn.embedding_lookup(self.cp_emb, self.cp_idx)

        edge_input = tf.concat([h_u, h_v, self.edge_feats, emb_cp], axis=1)
        in_dim = int(edge_input.shape[1])

        W1 = tf1.get_variable("edge_fc1", shape=[in_dim, 128], initializer=xavier_init)
        b1 = tf1.get_variable("edge_b1", shape=[128], initializer=zeros_init)
        W2 = tf1.get_variable("edge_fc2", shape=[128, 1], initializer=xavier_init)
        b2 = tf1.get_variable("edge_b2", shape=[1], initializer=zeros_init)

        z1 = tf.nn.relu(tf.matmul(edge_input, W1) + b1)
        z1 = tf.nn.dropout(z1, keep_prob=1.0 - self.dropout_ph)
        logits = tf.squeeze(tf.matmul(z1, W2) + b2, axis=1)
        self.probs = tf.nn.sigmoid(logits)

        self.loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
            labels=self.labels, logits=logits, pos_weight=self.pos_weight))
        self.opt = tf1.train.AdamOptimizer(self.lr).minimize(self.loss)

    def train_epoch(self, sess, loader, pos_weight, dropout, batch_size=None):
        """Run one shuffled pass over the training split.

        Args:
            batch_size: Edges per step; defaults to the global ``BATCH_SIZE``.

        Returns:
            The example-weighted mean training loss (0.0 for an empty split).
        """
        batch_size = BATCH_SIZE if batch_size is None else batch_size
        total, count = 0.0, 0
        for (u, v, e, y, cp) in loader.iter_batches(batch_size, "train", True):
            feed = {self.batch_size_ph: len(y), self.batch_u: u, self.batch_v: v,
                    self.edge_feats: e, self.labels: y.astype(np.float32),
                    self.cp_idx: cp, self.pos_weight: pos_weight, self.dropout_ph: dropout}
            _, batch_loss = sess.run([self.opt, self.loss], feed_dict=feed)
            total += batch_loss * len(y); count += len(y)
        return total / max(count, 1)

    def predict_all(self, sess, loader, split="valid", batch_size=None):
        """Return ``(labels, probabilities)`` for every edge of ``split``, in order.

        Dropout is disabled. An empty split yields two empty arrays.
        """
        batch_size = BATCH_SIZE if batch_size is None else batch_size
        yt, pt = [], []
        for (u, v, e, y, cp) in loader.iter_batches(batch_size, split, False):
            feed = {self.batch_size_ph: len(y), self.batch_u: u, self.batch_v: v,
                    self.edge_feats: e, self.cp_idx: cp, self.dropout_ph: 0.0}
            p = sess.run(self.probs, feed_dict=feed)
            yt.append(y); pt.append(p)
        if not yt:
            # np.concatenate([]) would raise ValueError.
            return np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.float32)
        return np.concatenate(yt), np.concatenate(pt)

In [None]:
G, feats, id_map, adj, deg = load_graph_base(OUT_PREFIX)
loader = EdgeLoader(OUT_PREFIX)
print("Edge feat dim:", loader.F, "| train pos rate:", loader.pos_rate)

# negatives / positives over the training labels, used as pos_weight in the
# weighted cross-entropy (the max() guards a split with no positives).
# Assumes labels are 0/1 — TODO confirm.
y_tr = loader.train[2]
pos_w = float((len(y_tr) - y_tr.sum()) / max(y_tr.sum(), 1))

# Read the cp-bucket count once and close the metadata file.
with open(OUT_PREFIX + "-meta.json") as meta_fh:
    cp_n_buckets = int(json.load(meta_fh).get("cp_n_buckets", 4096))

tf1.reset_default_graph()
model = EdgeSupGraphSAGE(feats, adj, deg, aggregator=AGGREGATOR,
                         dim_1=DIM_1, dim_2=DIM_2, samples_1=SAMPLES_1, samples_2=SAMPLES_2,
                         dropout=DROPOUT, lr=LR, cp_n_buckets=cp_n_buckets,
                         cp_emb_dim=CP_EMB_DIM, concat=True)

config = tf1.ConfigProto(); config.gpu_options.allow_growth = True
sess = tf1.Session(config=config); sess.run(tf1.global_variables_initializer())

In [None]:
# for ep in range(1, EPOCHS+1):
#     t0 = time.time()
#     loss_tr = model.train_epoch(sess, loader, pos_weight=pos_w, dropout=DROPOUT)
#     yv, pv = model.predict_all(sess, loader, "valid")
#     aucpr = average_precision_score(yv, pv) if len(yv)>0 else np.nan
#     auc   = roc_auc_score(yv, pv) if len(np.unique(yv))>1 else np.nan
#     f1    = f1_score(yv, (pv>=0.5).astype(int), zero_division=0) if len(yv)>0 else np.nan
#     print(f"[Epoch {ep}] loss={loss_tr:.5f} | val AUCPR={aucpr:.4f} ROC-AUC={auc:.4f} F1@0.5={f1:.4f} | {time.time()-t0:.1f}s")

In [None]:
# yt, pt = model.predict_all(sess, loader, "test")
# aucpr_t = average_precision_score(yt, pt) if len(yt)>0 else np.nan
# auc_t   = roc_auc_score(yt, pt) if len(np.unique(yt))>1 else np.nan
# f1_t    = f1_score(yt, (pt>=0.5).astype(int), zero_division=0) if len(yt)>0 else np.nan
# print(f"[TEST] AUCPR={aucpr_t:.4f} ROC-AUC={auc_t:.4f} F1@0.5={f1_t:.4f}")

In [None]:
import shutil

# Per-run output folders: <DATA_DIR>/results/<timestamp>/{ckpt,plots}.
RESULTS_ROOT = os.path.join(DATA_DIR, "results")
tz = ZoneInfo("America/Sao_Paulo")
run_id = datetime.now(tz).strftime("%Y-%m-%d_%H-%M-%S")
RUN_DIR = os.path.join(RESULTS_ROOT, run_id)
CKPT_DIR = os.path.join(RUN_DIR, "ckpt")
PLOTS_DIR = os.path.join(RUN_DIR, "plots")
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# Point results/latest at this run (best effort). The link target is the
# bare run_id because a relative symlink target is resolved against the
# link's own directory (RESULTS_ROOT), not the current working directory.
latest_link = os.path.join(RESULTS_ROOT, "latest")
try:
    if os.path.islink(latest_link):
        os.unlink(latest_link)
    elif os.path.exists(latest_link):
        shutil.rmtree(latest_link)
    os.symlink(run_id, latest_link, target_is_directory=True)
except OSError as err:
    # Some filesystems may not support symlinks; the run itself is unaffected.
    print(f"[warn] could not update 'latest' link: {err}")

exp_config = {
    "run_id": run_id,
    "data_dir": DATA_DIR,
    "proc_dir": PROC_DIR,
    "train_prefix": OUT_PREFIX,
    "epochs": int(EPOCHS),
    "dropout": float(DROPOUT),
    "pos_weight": float(pos_w) if 'pos_w' in globals() else None,
    # Remaining hyperparameters, recorded for reproducibility.
    "aggregator": AGGREGATOR,
    "dims": [int(DIM_1), int(DIM_2)],
    "samples": [int(SAMPLES_1), int(SAMPLES_2)],
    "batch_size": int(BATCH_SIZE),
    "lr": float(LR),
    "max_degree": int(MAX_DEGREE),
    "cp_emb_dim": int(CP_EMB_DIM),
    "seed": int(SEED),
}
with open(os.path.join(RUN_DIR, "config.json"), "w") as f:
    json.dump(exp_config, f, indent=2)

def save_pr_roc(y_true, y_score, split_name):
    """Write precision-recall and ROC curve PNGs for one data split.

    The images are saved under ``PLOTS_DIR`` as ``<split_name>_pr_curve.png``
    and ``<split_name>_roc_curve.png``. Their titles show the average precision
    and ROC-AUC, respectively. Nothing is written when ``y_true`` is empty or
    contains a single class, because neither curve is defined in that case.
    """
    if not (len(y_true) > 0 and np.unique(y_true).size >= 2):
        return

    def _finish_figure(xlabel, ylabel, title, filename):
        # Shared decoration, save and close for the current figure.
        plt.xlabel(xlabel); plt.ylabel(ylabel); plt.title(title)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(PLOTS_DIR, filename), dpi=150)
        plt.close()

    precision, recall, _ = precision_recall_curve(y_true, y_score)
    ap_value = average_precision_score(y_true, y_score)
    plt.figure()
    plt.step(recall, precision, where="post")
    _finish_figure("Recall", "Precision",
                   f"PR — {split_name} (AP={ap_value:.4f})",
                   f"{split_name}_pr_curve.png")

    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_score)
    roc_value = roc_auc_score(y_true, y_score)
    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate)
    plt.plot([0, 1], [0, 1], '--')
    _finish_figure("FPR", "TPR",
                   f"ROC — {split_name} (AUC={roc_value:.4f})",
                   f"{split_name}_roc_curve.png")

def maybe_get_txids(split):
    """Best-effort lookup of the transaction ids for ``split``.

    First tries ``loader.get_txids(split)``, then the attribute
    ``loader.<split>_txids``. The first lookup that does not raise wins, and
    its value is returned as-is. Returns None when both lookups raise.
    """
    lookups = (
        lambda: loader.get_txids(split),
        lambda: getattr(loader, f"{split}_txids"),
    )
    for lookup in lookups:
        try:
            return lookup()
        except Exception:
            # Deliberately broad: a missing loader or method just means "no ids".
            continue
    return None

def threshold_metrics(y_true, y_score, thr=0.5):
    """Binarize ``y_score`` at ``thr`` and score it against ``y_true``.

    Returns a dict with keys ``precision``, ``recall``, ``f1`` and
    ``accuracy``, in that order. Precision, recall and F1 use
    ``zero_division=0``. For an empty ``y_true`` every value is NaN.
    ``y_score`` must support ``>=`` with a scalar and ``.astype`` (for
    example a NumPy array).
    """
    if len(y_true) == 0:
        return dict.fromkeys(("precision", "recall", "f1", "accuracy"), np.nan)

    predicted = (y_score >= thr).astype(int)
    metrics = {}
    metrics["precision"] = precision_score(y_true, predicted, zero_division=0)
    metrics["recall"] = recall_score(y_true, predicted, zero_division=0)
    metrics["f1"] = f1_score(y_true, predicted, zero_division=0)
    metrics["accuracy"] = accuracy_score(y_true, predicted)
    return metrics

# ---- Training loop -----------------------------------------------------------
# Each epoch trains once and scores the validation split. It then logs metrics,
# rewrites the full history (CSV + JSON) and checkpoints whenever validation
# AUCPR improves. After the loop, the best checkpoint is restored, so every
# artifact written afterwards (validation/test predictions, test metrics)
# comes from the model selected on validation, not from the last epoch.

def _nan_to_none(value):
    """Return ``value`` as a float, or None if it is NaN (keeps JSON strictly valid)."""
    return None if np.isnan(value) else float(value)


history = []
best_val_aucpr = -np.inf
best_epoch = None
best_ckpt_path = None  # set once a checkpoint has actually been written
saver = tf.compat.v1.train.Saver(max_to_keep=3)

for ep in range(1, EPOCHS + 1):
    t0 = time.time()
    loss_tr = model.train_epoch(sess, loader, pos_weight=pos_w, dropout=DROPOUT)
    yv, pv = model.predict_all(sess, loader, "valid")
    aucpr = average_precision_score(yv, pv) if len(yv) > 0 else np.nan
    # ROC-AUC is undefined unless both classes are present.
    auc   = roc_auc_score(yv, pv) if np.unique(yv).size > 1 else np.nan
    m_val = threshold_metrics(yv, pv, thr=0.5)
    dur = time.time() - t0
    print(f"[Epoch {ep}] loss={loss_tr:.5f} | val AUCPR={aucpr:.4f} ROC-AUC={auc:.4f} "
          f"P={m_val['precision']:.4f} R={m_val['recall']:.4f} F1@0.5={m_val['f1']:.4f} Acc@0.5={m_val['accuracy']:.4f} | {dur:.1f}s")
    history.append({
        "epoch": ep,
        "train_loss": float(loss_tr),
        "val_aucpr": _nan_to_none(aucpr),
        "val_rocauc": _nan_to_none(auc),
        "val_precision_0.5": _nan_to_none(m_val["precision"]),
        "val_recall_0.5":    _nan_to_none(m_val["recall"]),
        "val_f1_0.5":        _nan_to_none(m_val["f1"]),
        "val_accuracy_0.5":  _nan_to_none(m_val["accuracy"]),
        "time_sec": float(dur),
    })
    # Rewritten every epoch so an interrupted run still leaves its history on disk.
    pd.DataFrame(history).to_csv(os.path.join(RUN_DIR, "metrics_history.csv"), index=False)
    with open(os.path.join(RUN_DIR, "metrics_history.json"), "w") as f:
        json.dump(history, f, indent=2)
    if not np.isnan(aucpr) and aucpr > best_val_aucpr:
        best_val_aucpr = aucpr
        best_epoch = ep
        ckpt_path = os.path.join(CKPT_DIR, f"best_val_aucpr_ep{ep:03d}.ckpt")
        # Saver.save returns the checkpoint prefix to pass to Saver.restore.
        best_ckpt_path = saver.save(sess, ckpt_path)
        with open(os.path.join(RUN_DIR, "best_val.json"), "w") as f:
            json.dump({"best_epoch": ep, "best_val_aucpr": float(best_val_aucpr), "ckpt": ckpt_path}, f, indent=2)

if best_ckpt_path is not None:
    # Evaluate the model selected on validation, not the last epoch's weights.
    saver.restore(sess, best_ckpt_path)
    print(f"restored best checkpoint: epoch {best_epoch} (val AUCPR={best_val_aucpr:.4f})")
if best_ckpt_path is not None or EPOCHS < 1:
    # Refresh yv/pv after a restore so the validation artifacts match the
    # weights in the session. With zero epochs, create them so later code
    # does not hit a NameError.
    yv, pv = model.predict_all(sess, loader, "valid")

# ---- Validation artifacts ----------------------------------------------------
# Uses yv/pv as left by the training code above. Writes:
#   - PR/ROC plots;
#   - per-row predictions, with tx ids when the loader supplies exactly one
#     per row;
#   - the confusion matrix at threshold 0.5, when both classes occur in yv.
if len(yv) > 0:
    save_pr_roc(yv, pv, "valid")
    val_df = pd.DataFrame(dict(y_true=yv.astype(int), y_score=pv.astype(float)))
    txv = maybe_get_txids("valid")
    if txv is not None and len(txv) == val_df.shape[0]:
        val_df.insert(0, "tx_id", txv)
    val_df.to_csv(os.path.join(RUN_DIR, "val_predictions.csv"), index=False)
    if np.unique(yv).size >= 2:
        cm_v = confusion_matrix(yv, np.where(pv >= 0.5, 1, 0))
        with open(os.path.join(RUN_DIR, "val_confusion_matrix.json"), "w") as f:
            json.dump(dict(labels=[0, 1], matrix=cm_v.tolist()), f, indent=2)

# ---- Test evaluation ---------------------------------------------------------
yt, pt = model.predict_all(sess, loader, "test")
aucpr_t = np.nan if len(yt) == 0 else average_precision_score(yt, pt)
# ROC-AUC needs both classes in the test labels.
auc_t   = np.nan if np.unique(yt).size <= 1 else roc_auc_score(yt, pt)
m_test  = threshold_metrics(yt, pt, thr=0.5)
print(f"[TEST] AUCPR={aucpr_t:.4f} ROC-AUC={auc_t:.4f} "
      f"P={m_test['precision']:.4f} R={m_test['recall']:.4f} F1@0.5={m_test['f1']:.4f} Acc@0.5={m_test['accuracy']:.4f}")

# Undefined (NaN) metrics are stored as JSON null; key order is preserved.
test_metrics = {
    name: (None if np.isnan(value) else float(value))
    for name, value in (
        ("test_aucpr", aucpr_t),
        ("test_rocauc", auc_t),
        ("test_precision_0.5", m_test["precision"]),
        ("test_recall_0.5", m_test["recall"]),
        ("test_f1_0.5", m_test["f1"]),
        ("test_accuracy_0.5", m_test["accuracy"]),
    )
}
with open(os.path.join(RUN_DIR, "test_metrics.json"), "w") as f:
    json.dump(test_metrics, f, indent=2)

# ---- Test artifacts ----------------------------------------------------------
# Mirrors the validation artifacts. Writes:
#   - PR/ROC plots;
#   - per-row predictions, with tx ids when the loader supplies exactly one
#     per row;
#   - the confusion matrix at threshold 0.5, when both classes occur in yt.
if len(yt) > 0:
    save_pr_roc(yt, pt, "test")
    test_df = pd.DataFrame(dict(y_true=yt.astype(int), y_score=pt.astype(float)))
    txt = maybe_get_txids("test")
    if txt is not None and len(txt) == test_df.shape[0]:
        test_df.insert(0, "tx_id", txt)
    test_df.to_csv(os.path.join(RUN_DIR, "test_predictions.csv"), index=False)
    if np.unique(yt).size >= 2:
        cm_t = confusion_matrix(yt, np.where(pt >= 0.5, 1, 0))
        with open(os.path.join(RUN_DIR, "test_confusion_matrix.json"), "w") as f:
            json.dump(dict(labels=[0, 1], matrix=cm_t.tolist()), f, indent=2)

# ---- Top-k test predictions --------------------------------------------------
# Dumps the (up to) 200 highest-scoring test rows with their labels and
# scores, plus tx ids when one is available per test row.
if len(yt) > 0:
    k = min(200, len(yt))
    topk_idx = np.argsort(-pt)[:k]  # indices in descending score order
    topk = dict(
        index=topk_idx.tolist(),
        y_true=yt[topk_idx].astype(int).tolist(),
        y_score=pt[topk_idx].astype(float).tolist(),
    )
    if 'txt' in locals() and txt is not None and len(txt) == len(yt):
        topk["tx_id"] = [str(txt[row]) for row in topk_idx]
    with open(os.path.join(RUN_DIR, "test_topk.json"), "w") as f:
        json.dump(topk, f, indent=2)

# Report where everything went, preferring the "latest" link when it resolves.
print(f"\nresultados salvos em:\n{RUN_DIR}\n(atalho: {latest_link if os.path.exists(latest_link) else RUN_DIR})")