# arXiv Papers

Predict the subject category of arXiv research papers (node property prediction on the ogbn-arxiv citation graph).

https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv

https://en.wikipedia.org/wiki/ArXiv

The TensorFlow-GNN graph schema (`data/schema.pbtxt`) was generated with tensorflow_gnn's `convert_ogb_dataset.py` script.


In [None]:
%pip install ogb
%pip install tensorflow_gnn

In [7]:
import functools
import itertools
import numpy as np
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import NodePropPredDataset
import os
import tensorflow_gnn as tfgnn
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn import runner
from tensorflow_gnn.models import mt_albis
import tensorflow as tf


In [8]:
# Read the GraphSchema protobuf from disk and derive the GraphTensorSpec that
# the parser, sampler and model below are all built against.
graph_schema = tfgnn.read_schema(
    "data/schema.pbtxt"
)
graph_tensor_spec = tfgnn.create_graph_spec_from_schema_pb(
    graph_schema
)

In [9]:
# Load dataset, index splits, and labels
# NodePropPredDataset downloads/processes the data and loads it while it is
# being constructed (its __init__ calls pre_process() itself), so no explicit
# pre_process() call is needed here.
dataset = NodePropPredDataset(name="ogbn-arxiv", root="dataset/")
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
# dataset[0] is a (graph dict, labels array) pair; the graph dict provides the
# "node_feat", "node_year" and "edge_index" entries used further down.
graph, labels = dataset[0]

In [10]:
# Load dataset into GraphTensor
# Set to True to train on a small random graph instead of the full dataset.
use_mock_data = False
if use_mock_data:
    # Generate randomized mock data for model debugging on a resource-constrained machine
    num_nodes = 1000
    graph_tensor = tfgnn.random_graph_tensor(
        graph_tensor_spec, row_lengths_range=(num_nodes, num_nodes + 1)
    )
    # 70% / 15% / 15% train/valid/test split over the mock node ids. All three
    # splits must be overridden: otherwise test_idx keeps the real
    # ogbn-arxiv ids, which fall outside the mock graph.
    train_end = num_nodes * 7 // 10
    valid_end = num_nodes * 17 // 20
    train_idx = np.arange(0, train_end, 1, dtype=int)
    valid_idx = np.arange(train_end, valid_end, 1, dtype=int)
    test_idx = np.arange(valid_end, num_nodes, 1, dtype=int)
    node_features = graph_tensor.node_sets["nodes"].get_features_dict()
    node_features["#id"] = np.arange(num_nodes)
    # Borrow the real labels of the first num_nodes papers as mock targets.
    node_features["labels"] = labels[:num_nodes]
    graph_tensor = graph_tensor.replace_features(node_sets={"nodes": node_features})
else:
    # Transform data dicts into GraphTensor
    # Compose tf.train.Example from graph dict, then read into GraphTensor
    # (Dict -> tf.train.Example -> GraphTensor)
    node_ids = tf.train.Feature(
        int64_list = tf.train.Int64List(value=np.arange(len(labels)))
    )
    num_nodes = tf.train.Feature(
        int64_list = tf.train.Int64List(value=[len(labels)])
    )
    node_years = tf.train.Feature(
        int64_list = tf.train.Int64List(value=graph["node_year"].flatten())
    )
    node_features = tf.train.Feature(
        float_list = tf.train.FloatList(value=graph["node_feat"].flatten())
    )
    node_labels = tf.train.Feature(
        int64_list = tf.train.Int64List(value=labels.flatten())
    )
    # edge_index row 0 holds edge sources and row 1 holds edge targets.
    edges_size = tf.train.Feature(
        int64_list = tf.train.Int64List(value=[len(graph["edge_index"][0])])
    )
    edge_sources = tf.train.Feature(
        int64_list = tf.train.Int64List(value=graph["edge_index"][0])
    )
    edge_targets = tf.train.Feature(
        int64_list = tf.train.Int64List(value=graph["edge_index"][1])
    )
    # NOTE(review): keys follow the "<nodes|edges>/<set name>.<feature>" naming
    # that parse_single_example presumably expects for this schema — keep in
    # sync with data/schema.pbtxt.
    graph_example = tf.train.Example(
        features=tf.train.Features(feature={
            "nodes/nodes.#id": node_ids,
            "nodes/nodes.#size": num_nodes,
            "nodes/nodes.year": node_years,
            "nodes/nodes.feat": node_features,
            "nodes/nodes.labels": node_labels,
            "edges/edges.#size": edges_size,
            "edges/edges.#source": edge_sources,
            "edges/edges.#target": edge_targets
        })
    )
    # Parse the tf.train.Example graph into graph tensor
    graph_tensor = tfgnn.parse_single_example(graph_tensor_spec, graph_example.SerializeToString())

In [12]:
# Define model layers
# Feature processors (handed to runner.run): add a "seed" readout anchored at
# the first node of the "nodes" set, then move the "labels" feature onto that
# readout, which is where the task's label_feature_name="labels" refers to.
add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode(
    "seed", node_set_name="nodes"
)
move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
    "seed", feature_name="labels"
)
feature_processors = [add_readout, move_label_to_readout]


# Model hyperparameters.
num_graph_updates = 4  # number of stacked MtAlbisGraphUpdate rounds
# NOTE(review): 129 presumably equals the width of the initial node state
# (the "feat" width plus one "year" column) so that the residual next-state
# updates line up — confirm before changing.
node_state_dim = 129
message_dim = 256
edge_dropout_rate = 0
state_dropout_rate = 0.2
l2_regularization = 1e-5

# Define node mapping function
def set_initial_node_state(node_set, node_set_name):
    """Build the initial hidden state of a node set.

    The state is the "feat" feature concatenated with the "year" feature,
    cast to float32 and divided by 2000 as a rough normalization.
    `node_set_name` is accepted because MapFeatures passes it, but it is
    not used.
    """
    scaled_year = tf.cast(node_set["year"], tf.float32) / 2000
    return tf.keras.layers.Concatenate()([node_set["feat"], scaled_year])


# Configure model generator for core gnn update layers
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
    """Build the core GNN as a Keras model over GraphTensors of the given spec.

    Initial node states come from set_initial_node_state. Then
    num_graph_updates MtAlbis rounds run with the module-level
    hyperparameters. Every round except the last leaves node_set_names=None;
    the final round restricts it to ["nodes"].
    """
    inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_state)(inputs)
    last_round = num_graph_updates - 1
    for round_idx in range(num_graph_updates):
        target_node_sets = ["nodes"] if round_idx == last_round else None
        update_layer = mt_albis.MtAlbisGraphUpdate(
            attention_type="none",
            units=node_state_dim,
            message_dim=message_dim,
            # Messages are received at each edge's source node.
            receiver_tag=tfgnn.SOURCE,
            node_set_names=target_node_sets,
            simple_conv_reduce_type="mean|sum",
            edge_dropout_rate=edge_dropout_rate,
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
            normalization_type="layer",
            next_state_type="residual",
        )
        graph = update_layer(graph)
    return tf.keras.Model(inputs, graph)

In [13]:
# Configure KerasTrainer
# Training hyperparameters. The mock graph is tiny, so it gets a far smaller
# global batch than the full dataset.
global_batch_size = 50 if use_mock_data else 5000
epochs = 10
initial_learning_rate = 0.001

# Whole batches only: any remainder of a split is not counted as a step.
steps_per_epoch = len(train_idx) // global_batch_size
validation_steps = len(valid_idx) // global_batch_size

# Cosine-decay the learning rate over all training steps.
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch * epochs
)
optimizer_fn = functools.partial(
    tf.keras.optimizers.legacy.Adam, learning_rate=learning_rate
)

# 40-way classification read from the "labels" feature.
task = runner.NodeMulticlassClassification(
    num_classes=40,
    label_feature_name="labels",
)

# Initialize trainer on the current (default) distribution strategy, with
# checkpointing, summaries and backup/restore all disabled.
trainer = runner.KerasTrainer(
    strategy=tf.distribute.get_strategy(),
    model_dir="gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)

# Initialize model exporter; its output is named "paper_category_logits".
model_exporter = runner.KerasModelExporter(output_names="paper_category_logits")

In [None]:
# Build sampling spec & sampling model
# One hop: start from a seed in "nodes" and sample 5 edges from "edges",
# using uniform random sampling.
sampling_spec = (
    tfgnn.sampler.SamplingSpecBuilder(
        graph_schema, tfgnn.sampler.SamplingStrategy.RANDOM_UNIFORM
    )
    .seed("nodes")
    .sample(5, "edges")
    .build()
)

def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
    """Create the in-memory uniform edge sampler for one sampling op.

    Samples from the edge set named by the op in the module-level
    `graph_tensor`, drawing `sampling_op.sample_size` edges, and names the
    sampler after `sampling_op.op_name`.
    """
    return sampler.InMemUniformEdgesSampler.from_graph_tensor(
        graph_tensor,
        sampling_op.edge_set_name,
        sample_size=sampling_op.sample_size,
        name=sampling_op.op_name,
    )

def node_feature_accessor(node_set_name: tfgnn.NodeSetName):
    """Create an in-memory id-to-features accessor for one node set.

    The features are read from the module-level `graph_tensor`.
    """
    accessor_cls = sampler.InMemIndexToFeaturesAccessor
    return accessor_cls.from_graph_tensor(graph_tensor, node_set_name)

# Compile the spec into a Keras sampling model with int32 seed node ids.
# NOTE(review): the split index arrays are likely int64 — confirm the seeds
# are accepted (or cast) by this int32 model.
sampling_model = sampler.create_sampling_model_from_spec(
    graph_schema,
    sampling_spec,
    edge_sampler,
    node_feature_accessor,
    seed_node_dtype=tf.int32,
)

In [16]:
# Subclass DatasetProvider, wrapping the sampling spec & model defined above
class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset of sampled subgraphs, one per seed node id in `split_idx`.

    Each seed id is wrapped in a single-row ragged tensor and run through
    `sampling_model`. The result is unbatched, so the dataset yields one
    GraphTensor per seed.

    Args:
      full_graph_tensor: The full graph. It is stored as `self.graph_tensor`;
        sampling itself goes through `sampling_model`.
      sampling_model: Keras model that maps ragged seed ids to sampled
        subgraphs.
      sampling_spec: Spec the sampling model was built from, stored for
        reference.
      split_idx: Node ids to use as seeds.
      shuffle: If True, the seed order is reshuffled on every pass over the
        dataset. Intended for training. Leave it False when outputs must line
        up with `split_idx` order, as the evaluation code requires.
      shuffle_seed: Optional op-level seed for reproducible shuffling.
    """

    def __init__(
        self,
        full_graph_tensor: tfgnn.GraphTensor,
        sampling_model: tf.keras.Model,
        sampling_spec: tfgnn.sampler.SamplingSpec,
        split_idx: np.ndarray,
        shuffle: bool = False,
        shuffle_seed=None,
    ):
        super().__init__()
        self.graph_tensor = full_graph_tensor
        self.sampling_model = sampling_model
        self.sampling_spec = sampling_spec
        self.split_idx = split_idx
        self._shuffle = shuffle
        self._shuffle_seed = shuffle_seed  # None means nondeterministic order

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Return the tf.data.Dataset of sampled subgraphs for this split."""
        seeds = tf.data.Dataset.from_tensor_slices(self.split_idx)
        # With several input pipelines, give each pipeline a disjoint share of
        # the seeds instead of having every pipeline read all of them.
        if context is not None and context.num_input_pipelines > 1:
            seeds = seeds.shard(context.num_input_pipelines, context.input_pipeline_id)
        if self._shuffle:
            # Seed ids are plain integers, so buffering the whole split is
            # cheap and gives a full uniform permutation each epoch.
            # max(1, ...) keeps the buffer valid for an empty split.
            seeds = seeds.shuffle(
                buffer_size=max(1, len(self.split_idx)),
                seed=self._shuffle_seed,
                reshuffle_each_iteration=True,
            )
        # Sample one seed at a time: each element becomes a ragged batch with
        # a single row that holds a single seed id.
        seeds = seeds.batch(1)
        seeds = seeds.map(lambda s: tf.RaggedTensor.from_row_lengths(s, tf.ones_like(s)))
        seeds = seeds.map(self.sampling_model)
        return seeds.unbatch().prefetch(tf.data.AUTOTUNE)


# Only the training split is shuffled. Validation and test keep split order.
train_ds_provider = SubgraphDatasetProvider(graph_tensor, sampling_model, sampling_spec, train_idx, shuffle=True)
valid_ds_provider = SubgraphDatasetProvider(graph_tensor, sampling_model, sampling_spec, valid_idx)
test_ds_provider = SubgraphDatasetProvider(graph_tensor, sampling_model, sampling_spec, test_idx)

In [None]:
# Train (and export) through the TF-GNN runner, using the dataset providers,
# model factory, task, optimizer, trainer and exporter configured above.
runner.run(
    # Data.
    gtspec=graph_tensor_spec,
    train_ds_provider=train_ds_provider,
    valid_ds_provider=valid_ds_provider,
    feature_processors=feature_processors,
    global_batch_size=global_batch_size,
    # Model and objective.
    model_fn=model_fn,
    task=task,
    # Optimization.
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    trainer=trainer,
    # Export.
    model_exporters=[model_exporter],
)

In [36]:
# Load the exported SavedModel from <model_dir>/export (presumably where the
# runner's exporter wrote it) and get its default serving signature.
saved_model = tf.saved_model.load(
    os.path.join(trainer.model_dir, "export")
)
signature_fn = saved_model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
]

def _clean_example_for_serving(graph_tensor):
    """Serialize a sampled subgraph as input for the serving signature.

    Drops the "labels" feature from the "nodes" node set, so the ground
    truth is not part of the model input. Then converts the GraphTensor to a
    tf.train.Example and returns its serialized bytes.
    """
    unlabeled = graph_tensor.remove_features(node_sets={"nodes": ["labels"]})
    return tfgnn.write_example(unlabeled).SerializeToString()

# Convert examples to serialized string format.
# Sample one subgraph per test seed (in test_idx order, since this provider
# does not shuffle) and serialize each one without its label.
num_test_cases = len(test_idx)
test_ds_provider = SubgraphDatasetProvider(graph_tensor, sampling_model, sampling_spec, test_idx[:num_test_cases])
demo_ds = test_ds_provider.get_dataset(tf.distribute.InputContext())
serialized_examples = [
    _clean_example_for_serving(gt)
    for gt in itertools.islice(demo_ds, num_test_cases)
]

# Inference on the test split: every example is sent in one batch.
ds = tf.data.Dataset.from_tensor_slices(serialized_examples)
# The name "examples" for serialized tf.Example protos is defined by the runner.
input_dict = {"examples": next(iter(ds.batch(num_test_cases)))}

# Outputs are in the form of logits.
output_dict = signature_fn(**input_dict)
logits = output_dict["paper_category_logits"]  # As configured above.
probabilities = tf.math.softmax(logits).numpy()
classes = probabilities.argmax(axis=1)

# OGB evaluator: y_true and y_pred are column vectors of shape (n, 1), with
# y_true taken from the (flattened) labels at the evaluated test ids.
evaluator = Evaluator(name="ogbn-arxiv")
y_true = labels.reshape(-1)[test_idx[:num_test_cases]].reshape(-1, 1)
y_pred = classes.reshape(-1, 1)
ogb_evaluator_dict = {"y_true": y_true, "y_pred": y_pred}
result_dict = evaluator.eval(ogb_evaluator_dict)
print(result_dict)  # Last recorded accuracy -> {'acc': 0.6173487233298356}


{'acc': 0.6173487233298356}
