# Building a TF Data Loader
This notebook shows how we go from serialized TFRecords to training batches. A `tf.data` loader is built by chaining a sequence of operations that are applied to a stream of records.
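As a rough mental model, a `tf.data` pipeline behaves like a chain of lazy Python generators: each stage wraps the previous one, and no work happens until the loader is iterated. The stages below (`read_records`, `batch`, `map_stage`) are hypothetical toy stand-ins, not TensorFlow APIs; the real operations run in TensorFlow's runtime rather than in Python.

```python
# Toy analogue of a tf.data chain such as
#   TFRecordDataset(...).batch(2).map(fn)
# Each stage is a generator wrapping the previous one, so nothing runs
# until the final stage is iterated.

def read_records(records):
    """Source stage: yield records one at a time."""
    for record in records:
        yield record

def batch(stream, size):
    """Group consecutive elements into lists of length `size`.

    Like tf.data's default (drop_remainder=False), a final short batch is kept.
    """
    current = []
    for item in stream:
        current.append(item)
        if len(current) == size:
            yield current
            current = []
    if current:
        yield current

def map_stage(stream, fn):
    """Apply `fn` to every element of the stream."""
    for item in stream:
        yield fn(item)

pipeline = map_stage(batch(read_records(range(5)), 2), sum)
print(list(pipeline))  # [1, 5, 4]
```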

In [1]:
from graphsage.mpnn.data import *  # also supplies the `repeat` helper used in combine_graphs
import numpy as np
import tensorflow as tf
import os
tf.enable_eager_execution()

In [2]:
data_file = os.path.join('..', 'data', 'output', 'water_clusters.proto')

## Part 1: Loading Records from Disk
Our records are stored as serialized protobuf messages (`tf.train.Example` records). The data stream starts by reading them from this file.

In [3]:
loader = tf.data.TFRecordDataset(data_file)
loader

<TFRecordDatasetV1 shapes: (), types: tf.string>

Note that this loader produces a dataset of strings: each element is one serialized record.

In [4]:
next(iter(loader))

<tf.Tensor: id=18, shape=(), dtype=string, numpy=b"\n\xcb\x02\n\x12\n\x06energy\x12\x08\x12\x06\n\x04\xd4\xfc\x94\xc2\n\x11\n\x08n_waters\x12\x05\x1a\x03\n\x01\t\n\x0f\n\x06n_atom\x12\x05\x1a\x03\n\x01\x1b\n\x0f\n\x06n_bond\x12\x05\x1a\x03\n\x01<\n'\n\x04atom\x12\x1f\x1a\x1d\n\x1b\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\nH\n\x04bond\x12@\x1a>\n<\x00\x00\x00\x00\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\n\x8c\x01\n\x0cconnectivity\x12|\x1az\nx\x00\x01\x00\x02\x00\x0b\x01\x00\x01\x06\x02\x00\x03\x04\x03\x05\x03\x07\x04\x03\x05\x03\x05\x18\x06\x01\x06\x07\x06\x08\x07\x03\x07\x06\x08\x06\x08\x0f\t\n\t\x0b\t\x13\n\t\x0b\x00\x0b\t\x0c\r\x0c\x0e\x0c\x16\r\x0c\r\x18\x0e\x0c\x0e\x0f\x0f\x08\x0f\x0e\x0f\x10\x0f\x11\x10\x0f\x10\x12\x11\x0f\x11\x15\x12
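Each of these strings is a `tf.train.Example` message in protobuf wire format. Integers in that format, both the length prefixes and the packed `Int64List` values, are base-128 varints. The hypothetical `decode_varint` helper below is a minimal sketch for reading such bytes by hand; in the pipeline, `tf.io.parse_example` does all of this for us.

```python
def decode_varint(data, pos=0):
    """Decode one protobuf base-128 varint starting at data[pos].

    Returns (value, next_pos). Each byte carries 7 payload bits, least
    significant group first; a set high bit means more bytes follow.
    """
    value = 0
    shift = 0
    while True:
        byte = data[pos]
        pos += 1
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return value, pos
        shift += 7

# Bytes copied from the record printed above
print(decode_varint(b"\x1b"))      # (27, 1): the n_atom value
print(decode_varint(b"\xcb\x02"))  # (331, 2): length prefix of the embedded Features message
```

Reading the record above this way, `n_waters` is `\t` (9), `n_atom` is `\x1b` (27), and `n_bond` is `<` (60), which is consistent with three atoms per water molecule.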

For performance, we shuffle the records and group them into batches before parsing, so that the parser handles a whole batch in a single call rather than one record at a time.

In [5]:
loader = tf.data.TFRecordDataset(data_file).shuffle(128).batch(2)
loader

<DatasetV1Adapter shapes: (?,), types: tf.string>
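The `shuffle(128)` call keeps a buffer of 128 records and draws randomly from it. The sketch below models that behaviour as described in the `tf.data.Dataset.shuffle` documentation (fill a buffer, emit a random element, refill its slot with the next input); `buffered_shuffle` is a hypothetical name, not the TensorFlow implementation.

```python
import random

def buffered_shuffle(stream, buffer_size, seed=None):
    """Sketch of buffered shuffling in the style of tf.data's shuffle."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered element
        buffer[idx] = item       # and replace it with the new one
    # Input exhausted: drain what is left in random order
    rng.shuffle(buffer)
    yield from buffer

out = list(buffered_shuffle(range(10), buffer_size=3, seed=0))
print(sorted(out) == list(range(10)))  # True: a permutation of the input
```

With a buffer smaller than the dataset the shuffle is only local: an element can move at most `buffer_size - 1` positions earlier than where it appeared in the file, so the buffer size trades memory for randomness.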

In [6]:
next(iter(loader))

<tf.Tensor: id=43, shape=(2,), dtype=string, numpy=
array([b'\n\x8a\x03\n\x12\n\x06energy\x12\x08\x12\x06\n\x04F\xf7\xc6\xc2\n\x11\n\x08n_waters\x12\x05\x1a\x03\n\x01\x0b\n\x0f\n\x06n_atom\x12\x05\x1a\x03\n\x01!\n\x0f\n\x06n_bond\x12\x05\x1a\x03\n\x01N\n-\n\x04atom\x12%\x1a#\n!\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\x00\x01\x01\nZ\n\x04bond\x12R\x1aP\nN\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\n\xb3\x01\n\x0cconnectivity\x12\xa2\x01\x1a\x9f\x01\n\x9c\x01\x00\x01\x00\x02\x00\x13\x01\x00\x01\x1b\x02\x00\x02\x0c\x03\x04\x03\x05\x03\x08\x03\x1a\x04\x03\x04\x0f\x05\x03\x05\x0c\x06\x07\x06\x08\x06\r\x06 \x07\x06\x07\x12\x08\x03\x08\x06\t\n\t\x0b\

We now have a batch of two serialized records.

Our next step is to convert each batch of records into a set of tensors.

To do so, we define which features of the protobuf message to read and their types: `FixedLenFeature` for scalars present in every record, and `VarLenFeature` for lists whose length varies from record to record.
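To see what the parser hands back, here is a toy pure-Python model of the output layout of `tf.io.parse_example` on already-decoded records. `parse_batch` is a hypothetical helper: fixed-length features become one dense value per record, and variable-length features become the three components of a `SparseTensor` (`indices`, `values`, `dense_shape`). As a simplification, this sketch always takes a default for fixed-length features, whereas the real op treats a fixed-length feature without a default as required.

```python
import math

def parse_batch(records, fixed, var_len):
    """Toy model of the output layout of tf.io.parse_example.

    records: decoded records, one dict per record
    fixed:   {name: default} for scalar fixed-length features
    var_len: names of variable-length features
    """
    out = {}
    for name, default in fixed.items():
        # One dense value per record
        out[name] = [r.get(name, default) for r in records]
    for name in var_len:
        indices, values, max_len = [], [], 0
        for row, r in enumerate(records):
            items = r.get(name, [])
            max_len = max(max_len, len(items))
            for col, v in enumerate(items):
                indices.append((row, col))
                values.append(v)
        # SparseTensor components: (indices, values, dense_shape)
        out[name] = (indices, values, (len(records), max_len))
    return out

records = [{'energy': -1.5, 'atom': [0, 1, 1]}, {'atom': [0, 1]}]
parsed = parse_batch(records, {'energy': math.nan}, ['atom'])
print(parsed['energy'])   # [-1.5, nan]
print(parsed['atom'][1])  # [0, 1, 1, 0, 1]: values concatenated in record order
print(parsed['atom'][2])  # (2, 3): batch size x longest list
```

The key detail for later steps is that the sparse `values` are every record's list concatenated in batch order.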

In [7]:
def parse_records(example_proto):
    """Parse a batch of serialized tf.train.Example records"""
    features = {
        # Scalars present once per record; a missing energy becomes NaN
        'energy': tf.io.FixedLenFeature([], tf.float32, default_value=np.nan),
        'n_atom': tf.io.FixedLenFeature([], tf.int64),
        'n_bond': tf.io.FixedLenFeature([], tf.int64),
        # Variable-length lists; parse_example returns these as SparseTensors
        'connectivity': tf.io.VarLenFeature(tf.int64),
        'atom': tf.io.VarLenFeature(tf.int64),
        'bond': tf.io.VarLenFeature(tf.int64),
    }
    return tf.io.parse_example(example_proto, features)

We apply this function to each batch in the chain with `map`:

In [8]:
loader = tf.data.TFRecordDataset(data_file).shuffle(128).batch(2).map(parse_records)
loader

<DatasetV1Adapter shapes: {atom: (?, ?), bond: (?, ?), connectivity: (?, ?), energy: (?,), n_atom: (?,), n_bond: (?,)}, types: {atom: tf.int64, bond: tf.int64, connectivity: tf.int64, energy: tf.float32, n_atom: tf.int64, n_bond: tf.int64}>

We now have Tensor objects!

## Part 2: Preprocessing to Make MPNN-Compatible Batches
TensorFlow now understands the data types, but the data are not yet in a form our MPNN can use.

In [9]:
next(iter(loader))

{'atom': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x21f787f3d48>,
 'bond': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x21f7afb5c08>,
 'connectivity': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x21f7b049388>,
 'energy': <tf.Tensor: id=108, shape=(2,), dtype=float32, numpy=array([-98.59649, -98.24457], dtype=float32)>,
 'n_atom': <tf.Tensor: id=109, shape=(2,), dtype=int64, numpy=array([33, 33], dtype=int64)>,
 'n_bond': <tf.Tensor: id=110, shape=(2,), dtype=int64, numpy=array([78, 76], dtype=int64)>}

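Our main issue is that the `SparseTensor` objects cannot be used in many TensorFlow operations, so we need to convert them to dense tensors. As a first step, we take the values of each sparse feature, which hold all of the batch's records concatenated together, and wrap them as a single-row `RaggedTensor`.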

In [10]:
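def prepare_for_batching(dataset):
    """Convert the variable-length (sparse) features into RaggedTensors.

    The values of all records in the batch are concatenated into a single
    row, which combine_graphs later flattens back into a dense vector."""
    for c in ['atom', 'bond', 'connectivity']:
        # .values holds every record's entries, concatenated in batch order
        expanded = tf.expand_dims(dataset[c].values, axis=0, name=f'expand_{c}')
        # Shape (1, total_length) -> RaggedTensor with one row
        dataset[c] = tf.RaggedTensor.from_tensor(expanded)
    return dataset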
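Because the values are concatenated in batch order, the single ragged row holds every graph's atoms back to back; the per-graph boundaries survive only in `n_atom` (and `n_bond`). A minimal sketch of recovering them with a hypothetical `split_by_counts` helper is below; in TensorFlow, `tf.RaggedTensor.from_row_lengths` performs the analogous split.

```python
from itertools import accumulate

def split_by_counts(flat, counts):
    """Split a concatenated list back into per-record pieces.

    counts[i] is the length of record i (e.g. its n_atom value).
    """
    ends = list(accumulate(counts))
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]

flat_atoms = [0, 1, 1, 0, 1]  # atom lists of two records, back to back
print(split_by_counts(flat_atoms, [3, 2]))  # [[0, 1, 1], [0, 1]]
```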

In [11]:
loader = tf.data.TFRecordDataset(data_file).shuffle(128).batch(2).map(parse_records).map(prepare_for_batching)
loader

<DatasetV1Adapter shapes: {atom: (?, ?), bond: (?, ?), connectivity: (?, ?), energy: (?,), n_atom: (?,), n_bond: (?,)}, types: {atom: tf.int64, bond: tf.int64, connectivity: tf.int64, energy: tf.float32, n_atom: tf.int64, n_bond: tf.int64}>

In [12]:
next(iter(loader))

{'atom': <tf.RaggedTensor [[0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1]]>,
 'bond': <tf.RaggedTensor [[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]>,
 'connectivity': <tf.RaggedTensor [[0, 1, 0, 2, 0, 4, 0, 14, 0, 29, 1, 0, 1, 18, 2, 0, 3, 4, 3, 5, 3, 10, 4, 0, 4, 3, 5, 3, 5, 15, 6, 7, 6, 8, 6, 22, 6, 31, 7, 6, 8, 6, 8, 27, 9, 10, 9, 11, 9, 23, 9, 28, 10, 3, 10, 9, 11, 9, 12, 13, 12, 14, 12, 32, 13, 12, 14, 0, 14, 12, 15, 5,

The data are now close to the form we need, except for three problems:

- We can't easily tell which "node" belongs to which training entry, because we have put multiple graphs into the same batch
- The `connectivity` array has the wrong shape: it is a 1D array instead of an N×2 array of node pairs
- The node IDs in the connectivity arrays are wrong. Since we have merged multiple graphs, node 0 of the second graph is no longer at position 0 in the `atom` array
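The three fixes can be sketched in plain Python on a toy batch of two graphs. `repeat_ids` and `combine_graphs_py` are hypothetical names; `repeat_ids` is our assumption about what the notebook's `repeat` helper (from the star import of `graphsage.mpnn.data`) does, namely repeat `values[i]` `counts[i]` times, in the manner of `numpy.repeat`.

```python
from itertools import accumulate

def repeat_ids(values, counts):
    """Assumed behaviour of `repeat`: values[i] repeated counts[i] times."""
    return [v for v, c in zip(values, counts) for _ in range(c)]

def combine_graphs_py(n_atom, n_bond, flat_connectivity):
    # Problem 1: record which graph each node and each bond belongs to
    graph_ids = list(range(len(n_atom)))
    node_graph = repeat_ids(graph_ids, n_atom)
    bond_graph = repeat_ids(graph_ids, n_bond)
    # Problem 2: pair up the flat [a0, b0, a1, b1, ...] list into (a, b) rows
    pairs = [flat_connectivity[i:i + 2] for i in range(0, len(flat_connectivity), 2)]
    # Problem 3: exclusive cumulative sum of node counts gives each graph's offset
    offsets = [0] + list(accumulate(n_atom))[:-1]
    bond_offsets = repeat_ids(offsets, n_bond)
    connectivity = [[a + off, b + off] for (a, b), off in zip(pairs, bond_offsets)]
    return node_graph, bond_graph, connectivity

# Graph 0: 3 nodes, bonds (0, 1), (0, 2); graph 1: 2 nodes, bond (0, 1)
print(combine_graphs_py([3, 2], [2, 1], [0, 1, 0, 2, 0, 1]))
# ([0, 0, 0, 1, 1], [0, 0, 1], [[0, 1], [0, 2], [3, 4]])
```

Graph 1's bond (0, 1) becomes (3, 4) because graph 0 contributes three nodes ahead of it; that shift is the exclusive cumulative sum that `tf.cumsum(..., exclusive=True)` computes in the next cell.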

In [13]:
def combine_graphs(batch):
    """Combine multiple graphs into a single network"""

    # Map each node and each bond to the index of the graph it came from
    batch_size = tf.size(batch['n_atom'], name='batch_size')
    mol_id = tf.range(batch_size, name='mol_inds')
    batch['node_graph_indices'] = repeat(mol_id, batch['n_atom'], axis=0)
    batch['bond_graph_indices'] = repeat(mol_id, batch['n_bond'], axis=0)

    # Flatten the ragged atom, bond, and connectivity arrays into 1D tensors
    for c in ['atom', 'bond', 'connectivity']:
        batch[c] = batch[c].flat_values

    # Reshape the connectivity matrix to (None, 2)
    batch['connectivity'] = tf.reshape(batch['connectivity'], (-1, 2))

    # Declare the shapes of the atom and bond vectors
    #  (needed for TF 1.14, which does not seem to infer them)
    for c in ['atom', 'bond']:
        batch[c].set_shape((None,))

    # Shift each graph's node IDs by the number of nodes in the preceding graphs
    offset_values = tf.cumsum(batch['n_atom'], exclusive=True)
    offsets = repeat(offset_values, batch['n_bond'], name='offsets', axis=0)
    batch['connectivity'] += tf.expand_dims(offsets, 1)

    return batch
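A cheap sanity check on the combined batch is that every bond, after its offset is added, connects two nodes belonging to the graph listed in `bond_graph_indices`. The hypothetical `check_combined` below works on plain lists; with eager execution one could feed it a batch's tensors converted via `.numpy().tolist()`.

```python
def check_combined(node_graph, bond_graph, connectivity):
    """Return True if every bond joins two nodes of the graph it is assigned to."""
    if len(bond_graph) != len(connectivity):
        return False
    n_nodes = len(node_graph)
    for g, (a, b) in zip(bond_graph, connectivity):
        # Both endpoints must be valid node indices...
        if not (0 <= a < n_nodes and 0 <= b < n_nodes):
            return False
        # ...and must belong to the bond's own graph
        if node_graph[a] != g or node_graph[b] != g:
            return False
    return True

print(check_combined([0, 0, 0, 1, 1], [0, 0, 1], [[0, 1], [0, 2], [3, 4]]))  # True
print(check_combined([0, 0, 0, 1, 1], [0, 0, 1], [[0, 1], [0, 2], [0, 1]]))  # False: offset missing
```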

In [14]:
loader = tf.data.TFRecordDataset(data_file).shuffle(128).batch(2).map(parse_records).map(prepare_for_batching).map(combine_graphs)
loader


<DatasetV1Adapter shapes: {atom: (?,), bond: (?,), connectivity: (?, 2), energy: (?,), n_atom: (?,), n_bond: (?,), node_graph_indices: (?,), bond_graph_indices: (?,)}, types: {atom: tf.int64, bond: tf.int64, connectivity: tf.int64, energy: tf.float32, n_atom: tf.int64, n_bond: tf.int64, node_graph_indices: tf.int32, bond_graph_indices: tf.int32}>

The tensors now have the shapes our MPNN needs. Because every step is a TensorFlow operation rather than Python code, TensorFlow can run the whole pipeline in its own runtime and parallelize it automatically, preparing batches while the model trains.

In [15]:
next(iter(loader))

{'atom': <tf.Tensor: id=645, shape=(63,), dtype=int64, numpy=
 array([0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0,
        1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1,
        1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1],
       dtype=int64)>,
 'bond': <tf.Tensor: id=646, shape=(148,), dtype=int64, numpy=
 array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)>,
 'connectivity': <tf.Tensor: id=648, shape=(148, 2), dtype=int64, numpy=
 array([[ 0,  1],
        [ 0,  2],
        [ 0, 1