# Introduction to TensorFlow Graph Neural Networks

* https://github.com/tensorflow/gnn

## Graph classification example from 
* https://www.kaggle.com/code/fidels/introduction-to-tf-gnn
## Updated to work with the current version of TF-GNN

In [1]:
from IPython.display import clear_output

## Imports

### To plot the molecule graphs, install pygraphviz (which requires Graphviz) and uncomment the import:

In [None]:
#import pygraphviz as pgv

In [None]:
from tqdm import tqdm
from IPython.display import Image
import numpy as np
import tensorflow as tf

# Useful for debugging
#tf.config.run_functions_eagerly(
#    True
#)

tf.get_logger().setLevel('ERROR')

import tensorflow_gnn as tfgnn
import tensorflow_datasets as tfds

from tensorflow_gnn import runner
from tensorflow_gnn.models import gat_v2

print(f'Using TensorFlow v{tf.__version__} and TensorFlow-GNN v{tfgnn.__version__}')
print(f'GPUs available: {tf.config.list_physical_devices("GPU")}')

## Loading the chemical molecules dataset
## The dataset contains four splits: train, test-iid, test-ood1 and test-ood2 (TFDS names them `train`, `validation`, `test` and `test2`)

In [3]:
dataset_splits, dataset_info = tfds.load('cardiotox', data_dir='data/tfds', with_info=True)

clear_output()

print(dataset_info.description)

Drug Cardiotoxicity dataset [1-2] is a molecule classification task to detect
cardiotoxicity caused by binding hERG target, a protein associated with heart
beat rhythm. The data covers over 9000 molecules with hERG activity.

Note:

1. The data is split into four splits: train, test-iid, test-ood1, test-ood2.

2. Each molecule in the dataset has 2D graph annotations which is designed to
facilitate graph neural network modeling. Nodes are the atoms of the molecule
and edges are the bonds. Each atom is represented as a vector encoding basic
atom information such as atom type. Similar logic applies to bonds.

3. We include Tanimoto fingerprint distance (to training data) for each molecule
in the test sets to facilitate research on distributional shift in graph domain.

For each example, the features include:
  atoms: a 2D tensor with shape (60, 27) storing node features. Molecules with
    less than 60 atoms are padded with zeros. Each atom has 27 atom features.
  pairs: a 3D tensor with 

## Detailed dataset information, including the number of samples per split

In [4]:
dataset_info

tfds.core.DatasetInfo(
    name='cardiotox',
    full_name='cardiotox/1.0.0',
    description="""
    Drug Cardiotoxicity dataset [1-2] is a molecule classification task to detect
    cardiotoxicity caused by binding hERG target, a protein associated with heart
    beat rhythm. The data covers over 9000 molecules with hERG activity.
    
    Note:
    
    1. The data is split into four splits: train, test-iid, test-ood1, test-ood2.
    
    2. Each molecule in the dataset has 2D graph annotations which is designed to
    facilitate graph neural network modeling. Nodes are the atoms of the molecule
    and edges are the bonds. Each atom is represented as a vector encoding basic
    atom information such as atom type. Similar logic applies to bonds.
    
    3. We include Tanimoto fingerprint distance (to training data) for each molecule
    in the test sets to facilitate research on distributional shift in graph domain.
    
    For each example, the features include:
      atoms: a 2D

## A brief look at the raw training tensors that make up one molecule sample

In [5]:
sample = next(iter(dataset_splits['train']))

## Each sample is a dictionary of tensors

In [6]:
sample.keys()

dict_keys(['active', 'atom_mask', 'atoms', 'dist2topk_nbs', 'molecule_id', 'pair_mask', 'pairs'])

In [7]:
sample

{'active': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>,
 'atom_mask': <tf.Tensor: shape=(60,), dtype=float32, numpy=
 array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 'atoms': <tf.Tensor: shape=(60, 27), dtype=float32, numpy=
 array([[0., 1., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
 'dist2topk_nbs': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'molecule_id': <tf.Tensor: shape=(), dtype=string, numpy=b'CC1=C(C/C=C(\\C)CCC[C@H](C)CCC[C@H](C)CCCC(C)C)C(=O)c2ccccc2C1=O'>,
 'pair_mask': <tf.Tensor: shape=(6

## TF-GNN represents graphs using a data layout defined by a *Graph Schema*
## The schema is used to build a `GraphTensorSpec`

In [8]:
graph_schema_pbtxt = """
node_sets {
  key: "atom"
  value {
    description: "An atom in the molecule."

    features {
      key: "atom_features"
      value: {
        description: "[DATA] The features of the atom."
        dtype: DT_FLOAT
        shape { dim { size: 27 } }
      }
    }
  }
}

edge_sets {
  key: "bond"
  value {
    description: "A bond between two atoms in the molecule."
    source: "atom"
    target: "atom"

    features {
      key: "bond_features"
      value: {
        description: "[DATA] The features of the bond."
        dtype: DT_FLOAT
        shape { dim { size: 12 } }
      }
    }
  }
}

context {
  features {
    key: "toxicity"
    value: {
      description: "[LABEL] The toxicity class of the molecule (0 -> non-toxic; 1 -> toxic)."
      dtype: DT_INT64
    }
  }
  
  features {
    key: "molecule_id"
    value: {
      description: "[LABEL] The id of the molecule."
      dtype: DT_STRING
    }
  }
}
"""

In [9]:
graph_schema = tfgnn.parse_schema(graph_schema_pbtxt)
graph_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)

In [10]:
graph_spec

GraphTensorSpec({'context': ContextSpec({'features': {'toxicity': TensorSpec(shape=(1,), dtype=tf.int64, name=None), 'molecule_id': TensorSpec(shape=(1,), dtype=tf.string, name=None)}, 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None), 'node_sets': {'atom': NodeSetSpec({'features': {'atom_features': TensorSpec(shape=(None, 27), dtype=tf.float32, name=None)}, 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None)}, 'edge_sets': {'bond': EdgeSetSpec({'features': {'bond_features': TensorSpec(shape=(None, 12), dtype=tf.float32, name=None)}, 'adjacency': AdjacencySpec({'#index.0': TensorSpec(shape=(None,), dtype=tf.int32, name=None), '#index.1': TensorSpec(shape=(None,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, {'#index.0': 'atom', '#index.1': 'atom'}), 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None)}}, Tensor

## The function below converts the raw dataset tensors into the structures defined by the schema
## The raw dataset pads both atoms (nodes) and bonds (edges); the padding is removed using the masks provided in the dataset

In [11]:
def make_graph_tensor(datapoint):
    """
    Convert a datapoint from the TF-DS CardioTox dataset into a `GraphTensor`.
    """
    # atom_mask is non-zero only for real atoms
    # [ V, ]
    atom_indices = tf.squeeze(tf.where(datapoint['atom_mask']), axis=1)
    
    # only keep features of real atoms
    # [ V, 27 ]
    atom_features = tf.gather(datapoint['atoms'], atom_indices)
    
    # restrict the bond mask to real atoms
    # [ V, V ]
    pair_mask = tf.gather(tf.gather(datapoint['pair_mask'], atom_indices, axis=0), atom_indices, axis=1)
    
    # restrict the bond features to real atoms
    # [ V, V, 12 ]
    pairs = tf.gather(tf.gather(datapoint['pairs'], atom_indices, axis=0), atom_indices, axis=1)
    
    # pair_mask is non-zero only for real bonds
    # [ E, 2 ]
    bond_indices = tf.where(pair_mask)
    
    # only keep features of real bonds
    # [ E, 12 ]
    bond_features = tf.gather_nd(pairs, bond_indices)
    
    # separate sources and targets for each bond
    # [ E, ]
    sources, targets = tf.unstack(tf.transpose(bond_indices))

    # active is [1, 0] for non-toxic molecules, [0, 1] for toxic molecules
    # [ ]
    toxicity = tf.argmax(datapoint['active'])
    
    # the molecule_id is included for reference
    # [ ]
    molecule_id = datapoint['molecule_id']

    # create a GraphTensor from all of the above
    atom = tfgnn.NodeSet.from_fields(features={'atom_features': atom_features},
                                     sizes=tf.shape(atom_indices))
    
    atom_adjacency = tfgnn.Adjacency.from_indices(source=('atom', tf.cast(sources, dtype=tf.int32)),
                                                  target=('atom', tf.cast(targets, dtype=tf.int32)))
    
    bond = tfgnn.EdgeSet.from_fields(features={'bond_features': bond_features},
                                     sizes=tf.shape(sources),
                                     adjacency=atom_adjacency)
    
    context = tfgnn.Context.from_fields(features={'toxicity': [toxicity], 'molecule_id': [molecule_id]})
    
    return tfgnn.GraphTensor.from_pieces(node_sets={'atom': atom}, edge_sets={'bond': bond}, context=context)
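The padding removal in `make_graph_tensor` can be mirrored in NumPy on a toy molecule, which makes the index bookkeeping easy to check by hand. This is a minimal sketch, not part of the original notebook: `unpad_molecule` and the toy arrays are hypothetical, and only the masking logic (not the `GraphTensor` construction) is reproduced.

```python
import numpy as np

def unpad_molecule(atoms, atom_mask, pairs, pair_mask, active):
    """NumPy sketch of the padding removal performed by make_graph_tensor (hypothetical helper)."""
    # Indices of the real (non-padded) atoms: [V]
    atom_idx = np.flatnonzero(atom_mask)
    # Features of the real atoms only: [V, F_atom]
    atom_features = atoms[atom_idx]
    # Restrict the bond mask and bond features to real atoms: [V, V] and [V, V, F_bond]
    sub_mask = pair_mask[np.ix_(atom_idx, atom_idx)]
    sub_pairs = pairs[np.ix_(atom_idx, atom_idx)]
    # Each non-zero mask entry is a (source, target) bond, in row-major order: [E, 2]
    bond_idx = np.argwhere(sub_mask)
    # Features of the real bonds only: [E, F_bond]
    bond_features = sub_pairs[bond_idx[:, 0], bond_idx[:, 1]]
    sources, targets = bond_idx[:, 0], bond_idx[:, 1]
    # One-hot [non-toxic, toxic] label -> class index
    toxicity = int(np.argmax(active))
    return atom_features, bond_features, sources, targets, toxicity

# Toy molecule: 3 real atoms padded to 4, bonds 0-1 and 1-2 stored in both directions
atoms = np.arange(8, dtype=np.float32).reshape(4, 2)
atom_mask = np.array([1, 1, 1, 0], dtype=np.float32)
pair_mask = np.zeros((4, 4), dtype=np.float32)
for i, j in [(0, 1), (1, 0), (1, 2), (2, 1)]:
    pair_mask[i, j] = 1.0
pairs = np.zeros((4, 4, 1), dtype=np.float32)
pairs[..., 0] = 5.0 * pair_mask

atom_features, bond_features, sources, targets, toxicity = unpad_molecule(
    atoms, atom_mask, pairs, pair_mask, active=np.array([1, 0]))
print(atom_features.shape, bond_features.shape)  # -> (3, 2) (4, 1)
print(sources.tolist(), targets.tolist())        # -> [0, 1, 1, 2] [1, 0, 2, 1]
print(toxicity)                                  # -> 0
```

Like `tf.where`, `np.argwhere` returns indices in row-major order, so each bond's features stay aligned with its (source, target) pair.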


## Mapping the dataset into the ragged form without padded graph elements

In [12]:
train_dataset = dataset_splits['train'].map(make_graph_tensor)

## Taking one sample, same as the one in the raw dataset above

In [13]:
graph_tensor = next(iter(train_dataset))
graph_tensor

GraphTensor(
  context=Context(features={'toxicity': <tf.Tensor: shape=(1,), dtype=tf.int64>, 'molecule_id': <tf.Tensor: shape=(1,), dtype=tf.string>}, sizes=[1], shape=(), indices_dtype=tf.int32),
  node_set_names=['atom'],
  edge_set_names=['bond'])

## Verifying that the resulting graph tensor is compatible with the graph schema

In [14]:
graph_spec.is_compatible_with(graph_tensor)

True

## Looking into the transformed tensors

## The atom and bond sets below have sizes that depend on the actual numbers of atoms and bonds in each molecule (the feature dimensions stay fixed at 27 and 12)

In [15]:
graph_tensor.spec

GraphTensorSpec({'context': ContextSpec({'features': {'toxicity': TensorSpec(shape=(1,), dtype=tf.int64, name=None), 'molecule_id': TensorSpec(shape=(1,), dtype=tf.string, name=None)}, 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None), 'node_sets': {'atom': NodeSetSpec({'features': {'atom_features': TensorSpec(shape=(33, 27), dtype=tf.float32, name=None)}, 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None)}, 'edge_sets': {'bond': EdgeSetSpec({'features': {'bond_features': TensorSpec(shape=(68, 12), dtype=tf.float32, name=None)}, 'adjacency': AdjacencySpec({'#index.0': TensorSpec(shape=(68,), dtype=tf.int32, name=None), '#index.1': TensorSpec(shape=(68,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, {'#index.0': 'atom', '#index.1': 'atom'}), 'sizes': TensorSpec(shape=(1,), dtype=tf.int32, name=None)}, TensorShape([]), tf.int32, tf.int64, None)}}, TensorShape([]

## Edges are represented as two index lists: sources and targets

In [16]:
graph_tensor.edge_sets

{'bond': EdgeSet(features={'bond_features': <tf.Tensor: shape=(68, 12), dtype=tf.float32>}, sizes=[68], adjacency=Adjacency(source=('atom', <tf.Tensor: shape=(68,), dtype=tf.int32>), target=('atom', <tf.Tensor: shape=(68,), dtype=tf.int32>)))}

In [17]:
# Number of edges E in this molecule graph
graph_tensor.edge_sets['bond'].sizes

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([68], dtype=int32)>

In [18]:
graph_tensor.edge_sets['bond'].adjacency.source

<tf.Tensor: shape=(68,), dtype=int32, numpy=
array([ 0,  1,  1,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  5,  6,  7,  7,
        8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 15,
       16, 17, 17, 18, 18, 19, 19, 20, 20, 20, 21, 22, 23, 23, 23, 24, 25,
       25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 30, 31, 31, 31, 32],
      dtype=int32)>

In [19]:
graph_tensor.edge_sets['bond'].adjacency.target

<tf.Tensor: shape=(68,), dtype=int32, numpy=
array([ 1,  0,  2, 31,  1,  3, 23,  2,  4,  3,  5,  4,  6,  7,  5,  5,  8,
        7,  9,  8, 10,  9, 11, 12, 10, 10, 13, 12, 14, 13, 15, 14, 16, 17,
       15, 15, 18, 17, 19, 18, 20, 19, 21, 22, 20, 20,  2, 24, 25, 23, 23,
       26, 30, 25, 27, 26, 28, 27, 29, 28, 30, 25, 29, 31,  1, 30, 32, 31],
      dtype=int32)>
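The two index arrays printed above can be checked in plain Python: in an undirected molecule graph every directed edge (s, t) should have its reverse (t, s), and the number of distinct bonds should be half the edge count. This is a minimal sketch using the printed values; the helper names `is_symmetric` and `undirected_bonds` are hypothetical.

```python
# Source and target indices copied from the two outputs above (68 directed edges)
sources = [0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 5, 6, 7, 7,
           8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 15,
           16, 17, 17, 18, 18, 19, 19, 20, 20, 20, 21, 22, 23, 23, 23, 24, 25,
           25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 30, 31, 31, 31, 32]
targets = [1, 0, 2, 31, 1, 3, 23, 2, 4, 3, 5, 4, 6, 7, 5, 5, 8,
           7, 9, 8, 10, 9, 11, 12, 10, 10, 13, 12, 14, 13, 15, 14, 16, 17,
           15, 15, 18, 17, 19, 18, 20, 19, 21, 22, 20, 20, 2, 24, 25, 23, 23,
           26, 30, 25, 27, 26, 28, 27, 29, 28, 30, 25, 29, 31, 1, 30, 32, 31]

def is_symmetric(sources, targets):
    """True if every directed edge (s, t) has a matching reverse edge (t, s)."""
    edges = set(zip(sources, targets))
    return all((t, s) in edges for s, t in edges)

def undirected_bonds(sources, targets):
    """Distinct undirected bonds as (i, j) pairs with i < j."""
    return {(min(s, t), max(s, t)) for s, t in zip(sources, targets)}

print(len(sources), len(undirected_bonds(sources, targets)), is_symmetric(sources, targets))
# -> 68 34 True
print(max(sources + targets) + 1)  # number of atoms referenced -> 33
```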

## Each edge in this undirected graph carries 12 features; every bond is stored twice, once in each direction, so the 68 edges correspond to 34 bonds

In [20]:
graph_tensor.edge_sets['bond']['bond_features']

<tf.Tensor: shape=(68, 12), dtype=float32, numpy=
array([[1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 0., 0., 0., 1., 1., 1., 1.

## Nodes (atoms) are stored simply as their 27-dimensional feature vectors

In [21]:
graph_tensor.node_sets['atom']

NodeSet(features={'atom_features': <tf.Tensor: shape=(33, 27), dtype=tf.float32>}, sizes=[33])

In [22]:
# Number of nodes V in this molecule
graph_tensor.node_sets['atom'].sizes

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([33], dtype=int32)>

In [23]:
graph_tensor.node_sets['atom']['atom_features']

<tf.Tensor: shape=(33, 27), dtype=float32, numpy=
array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
   

## The target for the GNN binary classifier is stored in the molecule context

In [24]:
graph_tensor.context.features['toxicity']

<tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>

## For performance, we store the dataset as TFRecords, which may also enable distributed GPU and TPU training

## From the original article:
### "tf.data.Dataset.cache or tf.data.Dataset.snapshot would be preferable, as they would allow for more optimizations such as e.g. sharding."

## These methods persist preprocessed data to files or memory, so that dataset preparation does not have to be repeated

## The code below simply writes each split to a single TFRecord file, without sharding it across multiple files

In [25]:
def create_tfrecords(dataset_splits, dataset_info):
    """
    Dump all splits of the given dataset to TFRecord files.
    """
    for split_name, dataset in dataset_splits.items():
        filename = f'data/{dataset_info.name}-{split_name}.tfrecord'
        
        print(f'Creating {filename}')

        # Map the raw dataset to GraphTensors with parallel processing
        dataset = dataset.map(make_graph_tensor, num_parallel_calls=tf.data.AUTOTUNE)

        # Serialize each GraphTensor as a tf.train.Example and write it to the TFRecord file
        with tf.io.TFRecordWriter(filename) as writer:
            # We explicitly limit iteration over the dataset to one pass
            for graph_tensor in tqdm(iter(dataset), total=dataset_info.splits[split_name].num_examples):
                example = tfgnn.write_example(graph_tensor)
                writer.write(example.SerializeToString())


## Creating TFRecord files

In [26]:
#create_tfrecords(dataset_splits, dataset_info)

## Dataset retrieval method: TFRecordDatasetProvider

In [27]:
train_dataset_provider = runner.TFRecordDatasetProvider(file_pattern='data/cardiotox-train.tfrecord')
valid_dataset_provider = runner.TFRecordDatasetProvider(file_pattern='data/cardiotox-validation.tfrecord')
test1_dataset_provider = runner.TFRecordDatasetProvider(file_pattern='data/cardiotox-test.tfrecord')
test2_dataset_provider = runner.TFRecordDatasetProvider(file_pattern='data/cardiotox-test2.tfrecord')

In [28]:
train_dataset_provider

<tensorflow_gnn.runner.input.datasets.TFRecordDatasetProvider at 0x7ff1277528b0>

In [29]:
train_dataset_provider.get_dataset

<bound method SimpleDatasetProvider.get_dataset of <tensorflow_gnn.runner.input.datasets.TFRecordDatasetProvider object at 0x7ff1277528b0>>

## Retrieving the dataset

## `tf.distribute.InputContext()` describes how the input is split across input pipelines and replicas; with its default arguments it represents a single pipeline feeding a single replica, so the dataset is not sharded

In [30]:
train_dataset = train_dataset_provider.get_dataset(context=tf.distribute.InputContext())
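`tf.distribute.InputContext` exposes `num_input_pipelines` and `input_pipeline_id`, and in distributed training each pipeline is typically given a disjoint subset of the data. The plain-Python sketch below shows round-robin sharding, which is what `tf.data.Dataset.shard(num_shards, index)` does element-wise; whether `TFRecordDatasetProvider` shards files or records is not shown in this notebook, and the file names are hypothetical.

```python
def shard(elements, num_input_pipelines, input_pipeline_id):
    """Round-robin sharding: keep element i if i % num_input_pipelines == input_pipeline_id."""
    return [x for i, x in enumerate(elements)
            if i % num_input_pipelines == input_pipeline_id]

files = [f"shard-{i}" for i in range(6)]  # hypothetical file names

print(shard(files, 1, 0))  # default InputContext: a single pipeline keeps everything
print(shard(files, 3, 1))  # -> ['shard-1', 'shard-4']
```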

## Unpacking the TFRecord dataset from its serialized form requires a map with the tensor spec that describes the data to deserialize

In [31]:
train_dataset = train_dataset.map(lambda serialized: tfgnn.parse_single_example(serialized=serialized, spec=graph_spec))

## Computing the class balance and the majority-class baseline accuracy

In [32]:
#labels = []
#for sample in train_dataset:
#    labels.append(sample.context['toxicity'].numpy()[0])
    

In [33]:
#len(labels)

In [34]:
#1-np.sum(labels)/len(labels)
# BASELINE --> 0.7361643415606316

## MAJORITY-CLASS BASELINE (always predict non-toxic) --> 0.7361
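The commented cells compute `1 - sum(labels) / len(labels)`, the fraction of non-toxic molecules. That equals the accuracy of always predicting the most frequent class only when label 0 is the majority, as it is here. A minimal NumPy sketch of both quantities, with hypothetical function names and toy labels:

```python
import numpy as np

def majority_class_baseline(labels):
    """Accuracy of a classifier that always predicts the most frequent label."""
    labels = np.asarray(labels)
    _, counts = np.unique(labels, return_counts=True)
    return counts.max() / labels.size

def one_minus_positive_rate(labels):
    """The quantity computed in the commented cells: 1 - sum(labels) / len(labels)."""
    labels = np.asarray(labels)
    return 1.0 - labels.sum() / labels.size

toy_labels = [0, 0, 0, 1]                   # hypothetical labels, 0 = non-toxic
print(majority_class_baseline(toy_labels))  # -> 0.75
print(one_minus_positive_rate(toy_labels))  # -> 0.75
```

With a toxic majority (e.g. `[1, 1, 0]`) the two would differ, which is why the general form is safer to report.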

In [35]:
graph_tensor = next(iter(train_dataset))

## The graph sample can be inspected again as above

In [36]:
graph_tensor

GraphTensor(
  context=Context(features={'toxicity': <tf.Tensor: shape=(1,), dtype=tf.int64>, 'molecule_id': <tf.Tensor: shape=(1,), dtype=tf.string>}, sizes=[1], shape=(), indices_dtype=tf.int32),
  node_set_names=['atom'],
  edge_set_names=['bond'])

## Visualizing the sample graph

In [37]:
def draw_molecule(graph_tensor):
    """
    Plot the `GraphTensor` representation of a molecule.
    Requires `import pygraphviz as pgv` (commented out in the imports cell above).
    """
    (molecule_id, ) = graph_tensor.context['molecule_id'].numpy()
    (toxicity, ) = graph_tensor.context['toxicity'].numpy()

    sources = graph_tensor.edge_sets['bond'].adjacency.source.numpy()
    targets = graph_tensor.edge_sets['bond'].adjacency.target.numpy()

    pgvGraph = pgv.AGraph()
    pgvGraph.graph_attr['label'] = f'toxicity = {toxicity}\n\nmolecule_id = {molecule_id.decode()}'

    for edge in zip(sources, targets):
        pgvGraph.add_edge(edge)

    return Image(pgvGraph.draw(format='png', prog='dot'))

## Drawing the sample graph (requires Graphviz and pygraphviz)

In [38]:
#draw_molecule(graph_tensor)

## Batching the dataset

In [39]:
batch_size = 128
batched_train_dataset = train_dataset.batch(batch_size)

In [40]:
graph_tensor_item = next(iter(train_dataset))

In [41]:
graph_tensor_item

GraphTensor(
  context=Context(features={'toxicity': <tf.Tensor: shape=(1,), dtype=tf.int64>, 'molecule_id': <tf.Tensor: shape=(1,), dtype=tf.string>}, sizes=[1], shape=(), indices_dtype=tf.int32),
  node_set_names=['atom'],
  edge_set_names=['bond'])

In [42]:
graph_tensor_item.rank

0

In [43]:
graph_tensor_batch = next(iter(batched_train_dataset))

In [44]:
graph_tensor_batch

GraphTensor(
  context=Context(features={'toxicity': <tf.Tensor: shape=(128, 1), dtype=tf.int64>, 'molecule_id': <tf.Tensor: shape=(128, 1), dtype=tf.string>}, sizes=[[1]
 [1]
 ...
 [1]], shape=(128,), indices_dtype=tf.int32),
  node_set_names=['atom'],
  edge_set_names=['bond'])

In [45]:
graph_tensor_batch.rank

1

In [46]:
graph_tensor_item.node_sets['atom']['atom_features'].shape

TensorShape([33, 27])

## After batching, the individual samples are stored as ragged tensors, because each graph in the batch may contain a different number of nodes and edges; hence the node and edge dimensions are set to None

In [47]:
# Resulting shape: (batch_size, None, atom_features_dimension)
graph_tensor_batch.node_sets['atom']['atom_features'].shape

TensorShape([128, None, 27])

In [48]:
graph_tensor_batch.edge_sets['bond']['bond_features'].shape

TensorShape([128, None, 12])

## GNN layers expect a single scalar (rank-0) graph rather than a batch of separate graphs

## All graphs in the batch therefore need to be merged into one graph with batch_size disconnected components

## TF-GNN keeps track of these disconnected graphs automatically

In [49]:
scalar_graph_tensor = graph_tensor_batch.merge_batch_to_components()
scalar_graph_tensor.rank

0
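The core idea of merging a batch into one graph with disconnected components is to concatenate the node lists and shift each graph's edge endpoints by the number of nodes that precede it, while remembering the per-graph sizes. The NumPy sketch below illustrates only that index bookkeeping; `merge_graphs` is a hypothetical helper and not TF-GNN's implementation of `merge_batch_to_components`.

```python
import numpy as np

def merge_graphs(node_features_list, sources_list, targets_list):
    """Merge several graphs into one graph whose components are the input graphs."""
    sizes = np.array([len(x) for x in node_features_list])
    # Offset of each graph's first node in the merged node list
    offsets = np.concatenate([[0], np.cumsum(sizes)[:-1]])
    nodes = np.concatenate(node_features_list, axis=0)
    # Shift each graph's edge endpoints by its node offset
    sources = np.concatenate([np.asarray(s) + o for s, o in zip(sources_list, offsets)])
    targets = np.concatenate([np.asarray(t) + o for t, o in zip(targets_list, offsets)])
    return nodes, sources, targets, sizes

# Two toy graphs: 2 nodes (edges 0->1, 1->0) and 3 nodes (edges 0->1, 1->2)
nodes, sources, targets, sizes = merge_graphs(
    [np.zeros((2, 4)), np.ones((3, 4))],
    [[0, 1], [0, 1]],
    [[1, 0], [1, 2]])
print(nodes.shape, sizes.tolist())         # -> (5, 4) [2, 3]
print(sources.tolist(), targets.tolist())  # -> [0, 1, 2, 3] [1, 0, 3, 4]
```

Because no edge crosses between offset ranges, the components stay disconnected and message passing cannot mix molecules.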

## Now the numbers of nodes and edges are explicitly accessible again, with counts equal to the totals over all graphs in the batch

In [50]:
scalar_graph_tensor.node_sets['atom']['atom_features'].shape

TensorShape([3368, 27])

In [51]:
scalar_graph_tensor.edge_sets['bond']['bond_features'].shape

TensorShape([7312, 12])

In [52]:
scalar_graph_tensor.context.features['toxicity']

<tf.Tensor: shape=(128,), dtype=int64, numpy=
array([0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1,
       0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])>

In [53]:
scalar_graph_tensor.context.features

{'toxicity': <tf.Tensor: shape=(128,), dtype=int64, numpy=
array([0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1,
       0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])>, 'molecule_id': <tf.Tensor: shape=(128,), dtype=string, numpy=
array([b'CC1=C(C/C=C(\\C)CCC[C@H](C)CCC[C@H](C)CCCC(C)C)C(=O)c2ccccc2C1=O',
       b'CC(=O)N1CCN(c2cnc3cc(C(F)(F)F)cc(NCc4cccc([N+](=O)[O-])c4)c3c2)CC1',
       b'COc1c(Nc2ncc(Cl)c(Nc3ccccc3S(=O)(=O)N(C)C)n2)ccc2c1CCCC(N1CCN(CCO)CC1)C2',
       b'CCCCCCCC/C=C\\CCCCCCCC(=O)O',
       b'CCN(CC)C(=O)c1ccc([C@H](c2cccc(NC(=O)OC)c2)N2CCN(Cc3cccnc3)CC2)cc1',
       b'FC(F)(F)c1cc(CO[C@H]2CCCN[C@H]2c2ccccc2)cc(C(F)(F)F)c1',
       b'Cc1ccc(Br)cc1', b'CN

## Preparing elements for the GNN

## A typical GNN first embeds the per-node (and here also per-edge) features and then applies several message-passing layers to the graph

## Initial graph embedding

In [54]:
def get_initial_map_features(hidden_size, activation='relu'):
    """
    Initial pre-processing layer for a GNN
    """
    def node_sets_fn(node_set, node_set_name):
        if node_set_name == 'atom':
            return tf.keras.layers.Dense(units=hidden_size, activation=activation)(node_set['atom_features'])

    def edge_sets_fn(edge_set, edge_set_name):
        if edge_set_name == 'bond':
            return tf.keras.layers.Dense(units=hidden_size, activation=activation)(edge_set['bond_features'])

    return tfgnn.keras.layers.MapFeatures(node_sets_fn=node_sets_fn,
                                          edge_sets_fn=edge_sets_fn,
                                          name='graph_embedding')
    
            

## This function replaces atom_features and bond_features with hidden states of the specified dimension
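Each of the two Dense layers created by `get_initial_map_features` has a weight matrix and a bias (Keras `Dense` uses a bias by default), so the embedding's parameter count can be checked by hand. The total should agree with the 5248 parameters reported for `graph_embedding` in the model summaries further down. `dense_params` is a hypothetical helper.

```python
def dense_params(in_dim, units, use_bias=True):
    """Trainable parameters of a fully connected layer: a weight matrix plus an optional bias."""
    return in_dim * units + (units if use_bias else 0)

hidden_size = 128
atom_params = dense_params(27, hidden_size)  # 27 * 128 + 128 = 3584
bond_params = dense_params(12, hidden_size)  # 12 * 128 + 128 = 1664
print(atom_params + bond_params)             # -> 5248
```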

## The embedder layer instance

In [55]:
graph_embedding = get_initial_map_features(hidden_size=128)

In [None]:
embedded_graph = graph_embedding(scalar_graph_tensor)

In [57]:
scalar_graph_tensor.node_sets['atom'].features

{'atom_features': <tf.Tensor: shape=(3368, 27), dtype=float32, numpy=
array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 1.],
       [0., 1., 0., ..., 0., 0., 1.],
       [0., 1., 0., ..., 0., 0., 1.]], dtype=float32)>}

## After embedding, the features have a new dimension (128) and a new name, `hidden_state`

In [58]:
embedded_graph.node_sets['atom'].features

{'hidden_state': <tf.Tensor: shape=(3368, 128), dtype=float32, numpy=
array([[0.3377571 , 0.        , 0.07301895, ..., 0.        , 0.        ,
        0.        ],
       [0.32152247, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.32152247, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.35425305, 0.        , 0.08129819, ..., 0.01912414, 0.11002454,
        0.        ],
       [0.25396585, 0.        , 0.08569561, ..., 0.        , 0.        ,
        0.        ],
       [0.25396585, 0.        , 0.08569561, ..., 0.        , 0.        ,
        0.        ]], dtype=float32)>}
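The embedding is a Dense layer with ReLU applied to each atom independently, so atoms with identical input features receive identical embeddings; the second and third rows above (both starting with 0.32152247) come from atoms whose input feature rows are also identical. A minimal NumPy sketch with random, hypothetical weights:

```python
import numpy as np

rng = np.random.default_rng(0)

def dense_relu(x, w, b):
    """Per-node Dense layer with ReLU, as in the initial embedding (weights are random here)."""
    return np.maximum(x @ w + b, 0.0)

# Hypothetical binary atom features; rows 1 and 2 are identical
x = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [1, 0, 1, 0]], dtype=np.float32)
w = rng.normal(size=(4, 8)).astype(np.float32)
b = rng.normal(size=8).astype(np.float32)

h = dense_relu(x, w, b)
print(h.shape)                     # -> (3, 8)
print(np.array_equal(h[1], h[2]))  # -> True
```

Distinguishing such atoms is left to the message-passing layers, which mix in each atom's neighbourhood.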

In [59]:
embedded_graph.edge_sets['bond'].features

{'hidden_state': <tf.Tensor: shape=(7312, 128), dtype=float32, numpy=
array([[0.07793769, 0.17098849, 0.        , ..., 0.        , 0.19007693,
        0.        ],
       [0.07793769, 0.17098849, 0.        , ..., 0.        , 0.19007693,
        0.        ],
       [0.07289648, 0.52462673, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.459872  , 0.        , ..., 0.        , 0.        ,
        0.2028085 ],
       [0.        , 0.459872  , 0.        , ..., 0.        , 0.        ,
        0.2028085 ],
       [0.        , 0.459872  , 0.        , ..., 0.        , 0.        ,
        0.2028085 ]], dtype=float32)>}

## We use Graph Attention Network v2 (GATv2, arXiv:2105.14491), an improved variant of the Graph Attention Network (arXiv:1710.10903), as the stacked message-passing network
https://github.com/tensorflow/gnn/tree/main/tensorflow_gnn/models/gat_v2
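To illustrate what one attention hop computes, here is a simplified single-head NumPy sketch of GATv2-style scoring for one receiver node: logits e_j = a^T LeakyReLU(W_r h_i + W_s h_j), a softmax over neighbours, and a weighted sum of the projected sender states. It omits the edge features (which `sender_edge_feature=tfgnn.HIDDEN_STATE` adds in TF-GNN), biases and the final state update, so it is not TF-GNN's implementation; all names, shapes and the 0.2 LeakyReLU slope are assumptions.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)

def gatv2_aggregate(h_receiver, h_senders, W_r, W_s, a):
    """Simplified single-head GATv2-style attention for one receiver node.

    h_receiver: [D] receiver state; h_senders: [K, D] neighbour states;
    W_r, W_s: [D, C] projections; a: [C] attention vector.
    """
    q = h_receiver @ W_r            # [C]
    k = h_senders @ W_s             # [K, C]
    # GATv2 applies the nonlinearity before the dot product with `a`
    logits = leaky_relu(q + k) @ a  # [K]
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()        # softmax over neighbours
    return weights, weights @ k     # [K] attention weights, [C] aggregated message

rng = np.random.default_rng(0)
D, C, K = 4, 3, 5  # state size, channels, number of neighbours (toy values)
h_r, h_s = rng.normal(size=D), rng.normal(size=(K, D))
W_r, W_s, a = rng.normal(size=(D, C)), rng.normal(size=(D, C)), rng.normal(size=C)

weights, message = gatv2_aggregate(h_r, h_s, W_r, W_s, a)
print(weights.shape, message.shape)    # -> (5,) (3,)
print(round(float(weights.sum()), 6))  # -> 1.0
```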

In [60]:
class MPNN(tf.keras.layers.Layer):
    """
    A basic stack of message-passing Graph Attention layers.
    """
    def __init__(self, hidden_size, hops, name='gat_mpnn', **kwargs):
        super().__init__(name=name, **kwargs)
        self.hidden_size = hidden_size
        self.hops = hops
        
        self.mp_layers = [self._mp_factory(name=f'message_passing{i}') for i in range(hops)]

    def _mp_factory(self, name):
        return gat_v2.GATv2GraphUpdate(num_heads=1,
                                       per_head_channels=self.hidden_size,
                                       edge_set_name='bond',
                                       sender_edge_feature=tfgnn.HIDDEN_STATE,
                                       name=name)

    def get_config(self):
        config = super().get_config()
        config.update({
            'hidden_size': self.hidden_size,
            'hops': self.hops
        })
        return config

    def call(self, graph_tensor):
        for layer in self.mp_layers:
            graph_tensor = layer(graph_tensor)
        return graph_tensor


In [61]:
mpnn = MPNN(hidden_size=128, hops=8)

## The MPNN can process our initially embedded graph

In [62]:
hidden_graph = mpnn(embedded_graph)

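## After evaluation, the network keeps the graph structure intact and only transforms the features
## Atom states are updated: atoms with identical input features (e.g. the second and third atoms) now have different states, since their neighbourhoods differ. Bond states are unchanged, because the GATv2 layers update only node states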

In [63]:
hidden_graph.node_sets['atom'].features

{'hidden_state': <tf.Tensor: shape=(3368, 128), dtype=float32, numpy=
array([[0.        , 0.        , 0.        , ..., 0.18465093, 0.10301366,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.14064904, 0.15337469,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.13251466, 0.16547535,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.20664921, 0.16788206,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.19097888, 0.16449249,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.1963811 , 0.1651372 ,
        0.        ]], dtype=float32)>}

In [64]:
hidden_graph.edge_sets['bond'].features

{'hidden_state': <tf.Tensor: shape=(7312, 128), dtype=float32, numpy=
array([[0.07793769, 0.17098849, 0.        , ..., 0.        , 0.19007693,
        0.        ],
       [0.07793769, 0.17098849, 0.        , ..., 0.        , 0.19007693,
        0.        ],
       [0.07289648, 0.52462673, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.459872  , 0.        , ..., 0.        , 0.        ,
        0.2028085 ],
       [0.        , 0.459872  , 0.        , ..., 0.        , 0.        ,
        0.2028085 ],
       [0.        , 0.459872  , 0.        , ..., 0.        , 0.        ,
        0.2028085 ]], dtype=float32)>}

## The model

In [65]:
def vanilla_mpnn_model(graph_tensor_spec, # Input tensor spec
                       init_states_fn, # The graph embedding layer
                       pass_messages_fn # The GAT network
                      ):
    """
    Build the GNN using the Keras functional API
    """
    graph_tensor = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    embedded_graph = init_states_fn(graph_tensor)
    hidden_graph = pass_messages_fn(embedded_graph)
    return tf.keras.models.Model(inputs=graph_tensor, outputs=hidden_graph)

In [66]:
model = vanilla_mpnn_model(graph_tensor_spec=graph_spec,
                           init_states_fn=graph_embedding,
                           pass_messages_fn=mpnn)
model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [()]                      0         
                                                                 
 graph_embedding (MapFeatur  ()                        5248      
 es)                                                             
                                                                 
 gat_mpnn (MPNN)             ()                        396288    
                                                                 
Total params: 401536 (1.53 MB)
Trainable params: 401536 (1.53 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
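The summary figures can be cross-checked with simple arithmetic. Assuming all eight hops have the same configuration (as `MPNN._mp_factory` suggests, with 128-dimensional states throughout), each GATv2 layer contributes an equal share of the `gat_mpnn` parameters. The reported 1.53 MB is consistent with 4-byte float32 weights and 1 MB taken as 2^20 bytes.

```python
embedding_params = 5248  # graph_embedding row of the summary above
mpnn_params = 396288     # gat_mpnn row of the summary above
hops = 8

total_params = embedding_params + mpnn_params
per_hop = mpnn_params // hops
size_mib = total_params * 4 / 2**20  # float32 weights take 4 bytes each

print(total_params)        # -> 401536
print(per_hop)             # -> 49536
print(round(size_mib, 2))  # -> 1.53
```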


## Wrapper for the model for specifying hyperparameters

In [67]:
def get_model_creation_fn(hidden_size, hops, activation='relu'):#, l2_coefficient=1e-3):
    """
    Return model constuctor with specified hyperparameters
    One could also try different hidden sizes for the node and edge data
    """
    def model_creation_fn(graph_tensor_spec):
        initial_map_features = get_initial_map_features(hidden_size=hidden_size, activation=activation)
        mpnn = MPNN(hidden_size=hidden_size, hops=hops)

        model = vanilla_mpnn_model(graph_tensor_spec=graph_tensor_spec,
                                   init_states_fn=initial_map_features,
                                   pass_messages_fn=mpnn)
        
        # Adding global L2 regularization loss
        #model.add_loss(lambda: tf.reduce_sum([tf.keras.regularizers.l2(l2=l2_coefficient)(weight) for weight in model.trainable_weights]))

        return model


    return model_creation_fn
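In the training cell, `runner.run` receives `model_fn` and the graph spec (`gtspec`) as separate arguments, and presumably calls `model_fn` with the spec. The hyperparameters therefore have to be bound before the spec arrives, and that is the job of the closure. Below is a framework-free sketch of the same pattern. `get_model_creation_fn_sketch` is a hypothetical stand-in that returns a plain dict instead of a Keras model. It builds one factory per point of a small, made-up hyperparameter grid.

```python
def get_model_creation_fn_sketch(hidden_size, hops, activation="relu"):
    """Hypothetical stand-in for get_model_creation_fn: the returned closure
    remembers the hyperparameters until the graph spec is supplied."""
    def model_creation_fn(graph_tensor_spec):
        # The real version builds a Keras model here; this sketch only
        # records what the model would have been built with.
        return {"spec": graph_tensor_spec, "hidden_size": hidden_size,
                "hops": hops, "activation": activation}
    return model_creation_fn

# One factory per configuration of a small (made-up) grid.
factories = {(h, k): get_model_creation_fn_sketch(h, k)
             for h in (64, 128) for k in (4, 8)}

built = factories[(128, 8)]("graph_spec_placeholder")
print(built["hidden_size"], built["hops"])  # 128 8
```

`functools.partial` would work equally well. The closure simply keeps the hyperparameters next to the builder code.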



In [68]:
mpnn_creation_fn = get_model_creation_fn(hidden_size=128, hops=8)

In [69]:
model = mpnn_creation_fn(graph_spec)
model.summary()

Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [()]                      0         
                                                                 
 graph_embedding (MapFeatur  ()                        5248      
 es)                                                             
                                                                 
 gat_mpnn (MPNN)             ()                        396288    
                                                                 
Total params: 401536 (1.53 MB)
Trainable params: 401536 (1.53 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________




## GNN Training Task

## The TF-GNN Orchestrator (`runner`) defines a task protocol, which:
 * ### Adds a readout and prediction head: so far our model only processes graphs and produces no final output
 * ### Adds a loss
 * ### Defines metrics

## The TF-GNN Orchestrator provides a prototypical task for binary graph classification, *GraphBinaryClassification*

## We define an AUROC metric for the binary classifier

In [71]:
class AUROC(tf.keras.metrics.AUC):
    """
    AUROC metric computation for binary classification from logits

    y_true: true labels with shape (batch_size,) or (batch_size, 1)
    y_pred: logits with shape (batch_size, 1), one logit per graph; a sigmoid maps
            them to positive-class probabilities (AUC needs scores, not an argmax)
    """
    # A metric is a stateful object, so we override its update_state method
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Softmax over a length-1 axis is always 1.0, so apply a sigmoid to the single logit
        probabilities = tf.math.sigmoid(tf.reshape(y_pred, [-1]))
        y_true = tf.reshape(y_true, [-1])
        if sample_weight is not None:
            sample_weight = tf.reshape(sample_weight, [-1])
        return super().update_state(y_true, probabilities, sample_weight=sample_weight)
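The choice of activation matters here. The head produces a single logit per graph (the classifier output further down has shape `(1, 1)`), and a softmax over a length-1 axis always returns 1.0, so every score would tie. The pure-Python sketch below uses made-up logits. It computes an exact pairwise AUROC, not the threshold-based approximation that `tf.keras.metrics.AUC` uses.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def auroc(labels, scores):
    """Exact AUROC: the probability that a random positive outscores a random
    negative, with ties counted as 1/2. O(P*N), fine for a sketch."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

labels = [0, 0, 1, 1]
logits = [-1.2, 0.3, 0.1, 2.0]  # one logit per graph, made-up values

# Softmax over a length-1 axis is always [1.0], so every pair ties.
print(auroc(labels, [softmax([z])[0] for z in logits]))  # 0.5
# Sigmoid is monotonic, so it preserves the ranking of the logits.
print(auroc(labels, [sigmoid(z) for z in logits]))       # 0.75
print(auroc(labels, logits))                             # 0.75
```

AUROC depends only on the ranking of the scores. Any monotonic map of the logits therefore gives the same value as the raw logits.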

## Now the task definition: the idea is to take the model body and attach a head specialized for the prediction we care about, in this case a binary classifier.

## To address a different ML question, we could define a different head in another task and reuse the same body.

In [72]:
class GraphBinaryClassification(runner.GraphBinaryClassification):
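    """
    A GraphBinaryClassification task with a hidden layer (plus dropout) in the prediction head.
    Additional metrics such as AUROC can be added by uncommenting `metrics` below.

    Note (worth checking against your TF-GNN version): recent runner versions build the
    prediction head through `Task.predict()` rather than `adapt()`, in which case
    `runner.run` may not call this `adapt()` and only direct calls use it.
    """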
    def __init__(self, hidden_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._hidden_dim = hidden_dim
        
    def adapt(self, model):
        # Pool the node states into one graph-level vector (analogous to flattening
        # in a CNN) so that a final Dense layer can process them

        # Another version from 
        # https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md#the-big-picture-initialization-graph-updates-and-readout
        # pooled_features = tfgnn.keras.layers.Pool(tfgnn.CONTEXT, "mean", node_set_name="your_node_set")(graph)
        
        hidden_state = tfgnn.pool_nodes_to_context(model.output,
                                                   node_set_name=self._node_set_name,
                                                   reduce_type=self._reduce_type,
                                                   feature_name=self._state_name)

        hidden_state = tf.keras.layers.Dropout(0.2)(hidden_state)
        
        hidden_state = tf.keras.layers.Dense(units=self._hidden_dim, 
                                             activation='relu', name='hidden_layer')(hidden_state)
        
        hidden_state = tf.keras.layers.Dropout(0.2)(hidden_state)
                             
        logits = tf.keras.layers.Dense(units=self._units, name='logits')(hidden_state)
        #logits = tf.keras.layers.Dense(units=2, name='logits')(hidden_state)
        
        
        return tf.keras.Model(inputs=model.inputs, outputs=logits)

    #def metrics(self):
    #    # Concatenate metrics tuple
    #    return (*super().metrics(), AUROC(name='AUROC'))
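The readout collapses each graph's node states into a single vector. The pure-Python sketch below roughly illustrates what `tfgnn.pool_nodes_to_context` computes with a `'mean'` reduction. `mean_pool_nodes` is a hypothetical helper. It takes the concatenated node states of a batch together with the number of nodes in each graph. With a batch of one graph and hidden size 128, this gives the `(1, 128)` shape that appears in the model summary further down.

```python
def mean_pool_nodes(node_states, sizes):
    """Average the node state vectors of each graph.

    node_states: per-node vectors for all graphs in the batch, concatenated.
    sizes: number of nodes in each graph (assumed > 0, true for molecules).
    """
    pooled, start = [], 0
    for n in sizes:
        rows = node_states[start:start + n]
        # zip(*rows) iterates over feature columns; average each one.
        pooled.append([sum(col) / n for col in zip(*rows)])
        start += n
    return pooled

# Two toy graphs with 3 and 2 atoms, hidden size 2.
states = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0],
          [0.0, 1.0], [4.0, 5.0]]
print(mean_pool_nodes(states, [3, 2]))  # [[2.0, 2.0], [2.0, 3.0]]
```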

In [73]:
def extract_labels(graph_tensor):
    """
    Extract the toxicity class label from the *GraphTensor* input data
    Return a pair compatible with *tf.keras.Model.fit* method
    """
    #print(graph_tensor)
    #print(graph_tensor.shape)
    return graph_tensor, graph_tensor.context['toxicity']
    

In [74]:
#task = runner.RootNodeBinaryClassification(
#    "nodes",
#    label_fn=runner.ContextLabelFn("label"))

In [75]:
label_fn = runner.ContextLabelFn(feature_name="toxicity")

#label_fn = runner.RootNodeLabelFn(node_set_name="context", feature_name="toxicity")

In [76]:
task = GraphBinaryClassification(hidden_dim=256,
                                 node_set_name='atom',
                                 label_fn=label_fn)
                                 #label_feature_name='toxicity')#, num_classes=2)

In [77]:
task.losses()

<keras.src.losses.BinaryCrossentropy at 0x7ff127720370>
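`task.losses()` returns a `BinaryCrossentropy`. The head emits raw logits, so the runner presumably configures this loss with `from_logits=True`; that is an assumption, not something shown in the output. The per-example quantity can be written in a numerically stable form that never takes the log of 0. A pure-Python sketch:

```python
import math

def bce_from_logit(y, z):
    """Stable binary cross-entropy for label y in {0, 1} and logit z.
    Equals -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))]."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

def bce_naive(y, z):
    # Direct formula; breaks for large |z| because sigmoid(z) rounds to 0 or 1.
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

print(round(bce_from_logit(1, 0.0), 6))  # 0.693147 (ln 2)
print(bce_from_logit(0, 50.0))           # 50.0; bce_naive(0, 50.0) would hit log(0)
```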

In [78]:
task.metrics()

(<tensorflow_gnn.runner.tasks.classification.FromLogitsPrecision at 0x7ff12019f100>,
 <tensorflow_gnn.runner.tasks.classification.FromLogitsRecall at 0x7ff1277208e0>,
 <keras.src.metrics.confusion_metrics.AUC at 0x7ff1205358e0>,
 <keras.src.metrics.confusion_metrics.AUC at 0x7ff1380c1d00>,
 <keras.src.metrics.accuracy_metrics.BinaryAccuracy at 0x7ff127696250>,
 <keras.src.losses.BinaryCrossentropy at 0x7ff1276fb280>)

In [79]:
classificaiton_model = task.adapt(model)

In [80]:
classificaiton_model.summary()

Model: "model_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [()]                      0         
                                                                 
 graph_embedding (MapFeatur  ()                        5248      
 es)                                                             
                                                                 
 gat_mpnn (MPNN)             ()                        396288    
                                                                 
 pool_nodes_to_context (TFG  (1, 128)                  0         
 NNOpLambda)                                                     
                                                                 
 dropout (Dropout)           (1, 128)                  0         
                                                                 
 hidden_layer (Dense)        (1, 256)                  3302

In [81]:
classificaiton_model(graph_tensor)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2670362]], dtype=float32)>
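The printed value is a raw logit, not a probability. A sigmoid converts it into the model's estimate that this molecule has label 1. At this point the model is still untrained, so the number carries no meaning yet.

```python
import math

logit = 0.2670362  # value printed by classificaiton_model(graph_tensor) above
prob_label_1 = 1.0 / (1.0 + math.exp(-logit))
print(round(prob_label_1, 4))  # 0.5664
```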

In [82]:
trainer = runner.KerasTrainer(strategy=tf.distribute.get_strategy(),
                              model_dir='model')

In [83]:
def get_sgd():
    return tf.keras.optimizers.experimental.SGD(
    learning_rate=0.0005,
    momentum=0.9)
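Momentum accumulates a velocity. When the gradient keeps pointing the same way, the effective step approaches `learning_rate / (1 - momentum)`, which is about 10x larger for `momentum=0.9`. The sketch below uses the update rule documented for Keras SGD without Nesterov: `velocity = momentum * velocity - learning_rate * g; w = w + velocity`. It applies that rule to a toy 1-D loss `(w - 3)**2` that is unrelated to the notebook's model.

```python
def sgd_momentum(grad_fn, w0, learning_rate, momentum, steps):
    """Momentum SGD as documented for Keras SGD (nesterov=False)."""
    w, velocity = w0, 0.0
    for _ in range(steps):
        g = grad_fn(w)
        velocity = momentum * velocity - learning_rate * g
        w = w + velocity
    return w

def grad(w):
    # Gradient of the toy loss (w - 3)**2, minimized at w = 3.
    return 2.0 * (w - 3.0)

# Same learning rate as get_sgd(), 200 steps starting from w = 0.
plain = sgd_momentum(grad, 0.0, 0.0005, 0.0, 200)
with_momentum = sgd_momentum(grad, 0.0, 0.0005, 0.9, 200)
print(plain, with_momentum)  # momentum ends much closer to the minimum
```

With such a small learning rate, plain SGD barely moves in 200 steps. This is consistent with the long training run below.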

In [84]:
!rm -rf ./model

## Training

### This updated code trains much longer than the Kaggle source article does, because the original code did not work with the updated TF-GNN.

### The original optimizer was also failing, so we switched to SGD with momentum and trained for many epochs.

In [None]:
out = runner.run(
    train_ds_provider=train_dataset_provider,
    valid_ds_provider=valid_dataset_provider,
    #feature_processors=[extract_labels],
    model_fn=get_model_creation_fn(hidden_size=128, hops=8),
    task=task,
    trainer=trainer,
    epochs=1600,
    optimizer_fn=get_sgd,
    #optimizer_fn=tf.keras.optimizers.Adam,
    gtspec=graph_spec,
    global_batch_size=128
)



Epoch 1/1600


2024-01-20 22:03:03.714685: I external/local_xla/xla/service/service.cc:168] XLA service 0x7fefac170280 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-01-20 22:03:03.714707: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA RTX 6000 Ada Generation, Compute Capability 8.9
2024-01-20 22:03:03.723075: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904
I0000 00:00:1705788183.745148  497220 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/1600
...
Epoch 52/1600

In [None]:
%load_ext tensorboard


In [None]:
%tensorboard --logdir model --bind_all

In [None]:
0.343/1.178