In [67]:
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
from ogb.nodeproppred import NodePropPredDataset
import numpy as np
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from typing import Mapping
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
from tensorflow.keras.callbacks import TensorBoard
import functools

In [40]:
# Load ogbn-products together with its official train/valid/test node split
# (arrays of global node ids into the full product graph).
dataset = NodePropPredDataset(name="ogbn-products", root="dataset/")
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
valid_idx = split_idx["valid"]
test_idx = split_idx["test"]

# Split sizes, used later for the epoch / validation step counts.
NUM_TRAINING_SAMPLES = len(train_idx)
NUM_VALIDATION_SAMPLES = len(valid_idx)

In [70]:
# Restrict ogbn-products to its training nodes: gather their features/labels and
# keep an edge only when BOTH of its endpoints are training products.
graph, label = dataset[0]
train_node_feat = graph["node_feat"][train_idx]
train_label = label[train_idx]

mask0 = np.isin(graph["edge_index"][0], train_idx)  # source endpoint is a training node
mask1 = np.isin(graph["edge_index"][1], train_idx)  # target endpoint is a training node
mask = np.logical_and(mask0, mask1)
indices = np.flatnonzero(mask)
# NOTE: still global ogbn-products node ids, not positions within train_idx.
train_edge_index = graph["edge_index"][:, indices]

In [None]:
# Sanity-check the training subgraph built by the previous cell instead of
# re-running it verbatim (a duplicate run only repeated the np.isin filter over
# every ogbn-products edge and rebound the same variables to the same values).
assert train_edge_index.ndim == 2 and train_edge_index.shape[0] == 2, "edge index must be (2, num_edges)"
assert train_node_feat.shape[0] == train_idx.shape[0], "one feature row per training node"
assert train_label.shape[0] == train_idx.shape[0], "one label per training node"

In [42]:
# graph_schema = tfgnn.read_schema("graph_schema.pbtxt")
# graph_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)
# train_dataset_provider = runner.TFRecordDatasetProvider(file_pattern="train.tfrecord")
# train_dataset = train_dataset_provider.get_dataset(context=tf.distribute.InputContext())
# train_dataset = train_dataset.map(lambda serialized: tfgnn.parse_single_example(serialized=serialized, spec=graph_spec))
# graph_tensor = train_dataset.get_single_element()

In [43]:
graph_schema = tfgnn.read_schema("graph_schema.pbtxt")

# train_edge_index is sliced straight out of graph["edge_index"], so it holds
# *global* ogbn-products node ids. GraphTensor adjacency indices, however, must be
# row positions 0..N-1 of the "product" node set below, whose rows follow
# train_idx order. Translate every endpoint to its position in train_idx;
# argsort + searchsorted avoids a lookup table over all ogbn-products nodes.
train_idx_order = np.argsort(train_idx)
train_edge_index_local = train_idx_order[
    np.searchsorted(train_idx, train_edge_index, sorter=train_idx_order)
]
# Every endpoint must be a training node (edges were filtered on both ends);
# fail loudly rather than silently mis-wire edges if that ever stops holding.
assert np.array_equal(train_idx[train_edge_index_local], train_edge_index)

graph_tensor = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "product": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([train_idx.shape[0]]),
            features={
                # Global ogbn-products id of each row (row i <-> train_idx[i]).
                "id": tf.constant(train_idx),
                "feature": tf.constant(train_node_feat),
                "label": tf.constant(train_label),
            },
        )
    },
    edge_sets={
        "bought_together": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([train_edge_index_local.shape[1]]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("product", tf.constant(train_edge_index_local[0, :])),
                target=("product", tf.constant(train_edge_index_local[1, :])),
            ),
        )
    },
)

In [44]:
train_sampling_sizes = {
    "bought_together": 8,
}

def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]) -> tf.keras.Model:
    """Builds a Keras model that samples a 1-hop subgraph around each seed product.

    Args:
      full_graph_tensor: in-memory graph to sample edges and node features from.
      sizes: number of neighbors to sample, keyed by edge set name; only the
        "bought_together" entry is used.

    Returns:
      The model built by `sampler.create_sampling_model_from_spec` for the
      module-level `graph_schema`, taking int64 seed node indices.
    """

    def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
        # Read the sample size from the op itself so each sampler always agrees
        # with the spec op it serves (e.g. if one edge set were sampled at
        # several hops with different sizes).
        return sampler.InMemUniformEdgesSampler.from_graph_tensor(
            full_graph_tensor,
            sampling_op.edge_set_name,
            sample_size=int(sampling_op.sample_size),
        )

    def get_features(node_set_name: tfgnn.NodeSetName):
        return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
            full_graph_tensor, node_set_name
        )

    sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
    seed = sampling_spec_builder.seed("product")
    # The single sampling hop is recorded on the builder; its return value is
    # not needed because no further hops are chained from it.
    seed.sample(sizes["bought_together"], "bought_together", op_name="bt")
    sampling_spec = sampling_spec_builder.build()
    return sampler.create_sampling_model_from_spec(
        graph_schema, sampling_spec, edge_sampler, get_features, seed_node_dtype=tf.int64
    )

In [45]:
# def node_sets_fn(node_set, *, node_set_name):
#     features = node_set.get_features_dict()
#     ids = features.pop('id')
#     num_bins = 50000
#     features['hashed_id'] = tf.cast(tf.keras.layers.Hashing(num_bins=num_bins)(ids), tf.int32)
#     return features
# graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=node_sets_fn)(graph_tensor)

In [46]:
class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset provider that samples one rooted subgraph per seed node.

    Seeds are batched so the sampling model runs on many seeds per call; the
    resulting batch of subgraphs is then unbatched so the runner receives one
    GraphTensor per seed and applies its own batching.
    """

    def __init__(self,
                 full_graph_tensor: tfgnn.GraphTensor,
                 sizes: Mapping[str, int],
                 dataset: tf.data.Dataset,
                 sampling_batch_size: int = 128):
        """Builds the sampling model and records its output spec.

        Args:
          full_graph_tensor: in-memory graph to sample from.
          sizes: neighbors to sample per edge set name.
          dataset: dataset of scalar seed node indices into `full_graph_tensor`.
          sampling_batch_size: number of seeds sent through the sampling model
            per call (previously hard-coded to 128).
        """
        self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
        self.input_graph_spec = self._sampling_model.output.spec
        self._seed_dataset = dataset
        self._sampling_batch_size = sampling_batch_size
        # Shuffle over this provider's own seeds instead of the training-set size,
        # so e.g. a validation provider is not tied to NUM_TRAINING_SAMPLES.
        # Fall back to the previous value when the cardinality is not a positive
        # known count (unknown / infinite / empty).
        num_seeds = int(dataset.cardinality())
        self._shuffle_buffer_size = num_seeds if num_seeds > 0 else NUM_TRAINING_SAMPLES

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Creates an infinite, shuffled TF dataset of sampled subgraphs for this input pipeline."""
        ds = self._seed_dataset.shard(
            num_shards=context.num_input_pipelines, index=context.input_pipeline_id
        )
        ds = ds.shuffle(self._shuffle_buffer_size).repeat()
        ds = ds.batch(self._sampling_batch_size)
        ds = ds.map(
            self.sample,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.unbatch().prefetch(tf.data.AUTOTUNE)

    def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
        """Runs the sampling model on a batch of seeds.

        Args:
          seeds: 1-D tensor of seed node indices (one batch from `get_dataset`).

        Returns:
          The sampling model's output, with the seeds fed as a ragged tensor
          holding exactly one seed per row.
        """
        batch_size = tf.size(seeds)
        seeds_ragged = tf.RaggedTensor.from_row_lengths(
            seeds, tf.ones([batch_size], dtype=tf.int64)
        )
        return self._sampling_model(seeds_ragged)

In [47]:
# Seeds must be row positions in graph_tensor's "product" node set, whose rows are
# the training nodes in train_idx order -- not global ogbn-products node ids.
train_seed_positions = np.arange(NUM_TRAINING_SAMPLES, dtype=np.int64)
train_ds_provider = SubgraphDatasetProvider(graph_tensor, train_sampling_sizes, tf.data.Dataset.from_tensor_slices(train_seed_positions))
# FIXME: graph_tensor contains only training nodes, so valid_idx (global ids of
# validation nodes) cannot serve as seeds into it. Validation needs a graph that
# includes the validation nodes; this provider is currently not passed to runner.run.
valid_ds_provider = SubgraphDatasetProvider(graph_tensor, train_sampling_sizes, tf.data.Dataset.from_tensor_slices(valid_idx))

example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()

In [48]:
if tf.config.list_physical_devices("TPU"):
  print(f"Using TPUStrategy")
  min_nodes_per_component = {"paper": 1}
  strategy = runner.TPUStrategy("local")
  train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider, min_nodes_per_component)
  valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider, min_nodes_per_component)
elif tf.config.list_physical_devices("GPU"):
  print(f"Using MirroredStrategy for GPUs")
  gpu_list = !nvidia-smi -L
  print("\n".join(gpu_list))
  strategy = tf.distribute.MirroredStrategy()
  train_padding = None
  valid_padding = None
else:
  print(f"Using default strategy")
  strategy = tf.distribute.get_strategy()
  train_padding = None
  valid_padding = None
print(f"Found {strategy.num_replicas_in_sync} replicas in sync")

Using MirroredStrategy for GPUs
GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-b2dd31a4-1d4c-6d81-647d-e36ea0e64af9)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Found 1 replicas in sync


In [49]:
def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
    """Keeps only the "feature" and "label" fields of the "product" node set.

    Any other field of that node set (such as "id") is not carried over.
    Raises KeyError for any node set other than "product".
    """
    if node_set_name != "product":
        raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
    return {key: node_set[key] for key in ("feature", "label")}

def drop_all_features(_, **unused_kwargs):
    """Returns an empty feature dict, i.e. discards every feature of the given piece."""
    return dict()

process_features = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=process_node_features,
    context_fn=drop_all_features,
    edge_sets_fn=drop_all_features,
)

In [50]:
# Define a "seed" readout from the first "product" node of each sampled subgraph,
# then move that node's "label" into the readout as the "category" target.
add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode(
    "seed", node_set_name="product"
)
move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
    "seed",
    feature_name="label",
    new_feature_name="category",
    remove_input_feature=True,
)

In [51]:
# Preprocessing layers in application order: trim features, add the seed
# readout, then move the label into it.
feature_processors = [process_features, add_readout, move_label_to_readout]

In [52]:
node_state_dim = 128

def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
    """Projects the "product" input features to a ReLU state of width `node_state_dim`.

    Raises KeyError for any node set other than "product".
    """
    if node_set_name != "product":
        raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
    projection = tf.keras.layers.Dense(node_state_dim, activation="relu")
    return projection(node_set["feature"])

In [53]:
num_graph_updates = 1
message_dim = 128
state_dropout_rate = 0.2
l2_regularization = 1e-5

def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
    """Builds the GNN base model: initial node states followed by MtAlbis graph updates.

    Only the final update is restricted to the "product" node set; any earlier
    updates pass node_set_names=None (the layer's default).
    """
    inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
    last_update = num_graph_updates - 1
    for update_index in range(num_graph_updates):
        is_last = update_index == last_update
        update_layer = mt_albis.MtAlbisGraphUpdate(
            units=node_state_dim,
            message_dim=message_dim,
            receiver_tag=tfgnn.SOURCE,
            node_set_names=["product"] if is_last else None,
            simple_conv_reduce_type="mean|sum",
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
            normalization_type="layer",
            next_state_type="residual",
        )
        graph = update_layer(graph)
    return tf.keras.Model(inputs, graph)

In [54]:
# ogbn-products has 47 categories: np.unique(label.flatten()).shape gives 47.
# The target is the "category" feature that move_label_to_readout creates.
task = runner.NodeMulticlassClassification(
    label_feature_name="category",
    num_classes=47,
)

In [68]:
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001

# Set > 1 to shorten every epoch (e.g. for quick experiments). The former
# TPU / non-TPU branch assigned 1 in both cases, so the device check was a no-op.
epoch_divisor = 1

steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size // epoch_divisor
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size // epoch_divisor

# Cosine decay spans the whole run: steps_per_epoch * epochs optimizer steps.
learning_rate = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, steps_per_epoch * epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)

trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model",
    callbacks=[TensorBoard(log_dir="logs")],
    steps_per_epoch=steps_per_epoch,
    # Pass validation_steps=validation_steps here once a validation provider is
    # given to runner.run (it is currently valid_ds_provider=None).
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)

In [69]:
# Export the trained model with its output named "product_category".
model_exporter = runner.KerasModelExporter(output_names="product_category")
runner.run(
    # Input data (validation is disabled: no validation provider is passed).
    gtspec=example_input_graph_spec,
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,
    valid_ds_provider=None,
    valid_padding=valid_padding,
    global_batch_size=global_batch_size,
    # Model and objective.
    feature_processors=feature_processors,
    model_fn=model_fn,
    task=task,
    # Training loop and export.
    trainer=trainer,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    model_exporters=[model_exporter],
)

batch_size=Tensor("Size:0", shape=(), dtype=int32)
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10




INFO:tensorflow:Assets written to: /tmp/gnn_model/export/assets


INFO:tensorflow:Assets written to: /tmp/gnn_model/export/assets


RunResult(preprocess_model=<tf_keras.src.engine.functional.Functional object at 0x7b0ddffeffb0>, base_model=<tf_keras.src.engine.sequential.Sequential object at 0x7b0f19383e30>, trained_model=<tf_keras.src.engine.functional.Functional object at 0x7b0f2940d370>)

In [25]:
# Spec of a single (unbatched) sampled subgraph, as passed to runner.run as gtspec.
example_input_graph_spec

GraphTensorSpec({'context': ContextSpec({'features': {}, 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None), 'node_sets': {'product': NodeSetSpec({'features': {'id': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'feature': TensorSpec(shape=(None, 100), dtype=tf.int32, name=None), 'label': TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)}, 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None)}, 'edge_sets': {'bought_together': EdgeSetSpec({'features': {}, 'adjacency': AdjacencySpec({'#index.0': TensorSpec(shape=(None,), dtype=tf.int32, name=None), '#index.1': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, {'#index.0': 'product', '#index.1': 'product'}), 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None)}}, TensorShape([]), tf.int32, tf.int64, None)

In [16]:
# Build a standalone sampling model just to inspect the (batched, ragged) spec of its output.
model = create_sampling_model(graph_tensor, train_sampling_sizes)
model.output
# graph.edge_sets



<KerasTensor: type_spec=GraphTensorSpec({'context': ContextSpec({'features': {}, 'sizes': TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)}, TensorShape([None]), tf.int32, tf.int64, None), 'node_sets': {'product': NodeSetSpec({'features': {'id': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int64), 'feature': RaggedTensorSpec(TensorShape([None, None, 100]), tf.float32, 1, tf.int64), 'label': RaggedTensorSpec(TensorShape([None, None, 1]), tf.int64, 1, tf.int64)}, 'sizes': TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)}, TensorShape([None]), tf.int32, tf.int64, None)}, 'edge_sets': {'bought_together': EdgeSetSpec({'features': {}, 'adjacency': AdjacencySpec({'#index.0': RaggedTensorSpec(TensorShape([None, None]), tf.int32, 1, tf.int64), '#index.1': RaggedTensorSpec(TensorShape([None, None]), tf.int32, 1, tf.int64)}, TensorShape([None]), tf.int32, tf.int64, {'#index.0': 'product', '#index.1': 'product'}), 'sizes': TensorSpec(shape=(None, 1), dtype=tf.int32, name

In [14]:
# Inspect the in-memory "bought_together" edge set (sizes and adjacency) of graph_tensor.
graph_tensor.edge_sets

{'bought_together': EdgeSet(features={}, sizes=[10903266], adjacency=Adjacency(source=('product', <tf.Tensor: shape=(10903266,), dtype=tf.int64>), target=('product', <tf.Tensor: shape=(10903266,), dtype=tf.int64>)))}