# Tutorial 3: Collation and GNNs

In the previous tutorials, we saw how to use `readers` to pull data from disk, `transforms` to manipulate that data, and `datasets` to combine the two into a single sample. `dataloaders` batch those samples, but so far they relied on an implicit assumption about how to combine them: stack each tensor along a new leading batch axis.

For some applications, such as graph neural networks, that doesn't work: each sample can have a different number of nodes, so the tensors cannot be stacked. We need a different way to *collate* the data. In this tutorial we'll take the point cloud data generated for tutorial 2, find the k nearest neighbors (kNN) of each point, and use those neighbors to build a batched graph for frameworks like PyTorch Geometric (used in many `physicsnemo` graph models).
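
To see the problem concretely, here is a minimal plain-PyTorch sketch with two point clouds of different sizes (the first two samples in this dataset have 65,795 and 50,860 points; the tiny sizes below are illustrative). Default-style stacking fails, while concatenating along the node axis and recording which sample each node came from works. The name `batch` for that assignment vector follows PyG's convention.

```python
import torch

# Two point clouds with different numbers of points (sizes are illustrative)
cloud_a = torch.rand(5, 3)
cloud_b = torch.rand(3, 3)

# Default-style collation stacks samples on a new leading axis,
# which requires every sample to have exactly the same shape.
try:
    torch.stack([cloud_a, cloud_b])
    stacked_ok = True
except RuntimeError as err:
    stacked_ok = False
    print(f"stack failed: {err}")

# Graph-style collation instead concatenates along the node axis ...
nodes = torch.cat([cloud_a, cloud_b], dim=0)  # shape [8, 3]
# ... and keeps a vector mapping each node back to its sample.
batch = torch.repeat_interleave(
    torch.arange(2), torch.tensor([cloud_a.shape[0], cloud_b.shape[0]])
)
print(nodes.shape, batch.tolist())  # torch.Size([8, 3]) [0, 0, 0, 0, 0, 1, 1, 1]
```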

> NOTE: This example requires `torch_geometric`.

In [1]:
from typing import Any, Sequence

import torch
from torch_geometric.data import Batch as PyGBatch
from torch_geometric.data import Data as PyGData

# Import core datapipe components
from physicsnemo.datapipes import DataLoader, Dataset
from physicsnemo.datapipes.collate import Collator
from physicsnemo.datapipes.readers import ZarrReader
from physicsnemo.datapipes.transforms import KNearestNeighbors



## Section 1: Computing kNN graphs on the fly

Many graphs come with edge information already available. Other data, such as point clouds, is unstructured. We can impose a graph structure on the fly with a kNN search, here backed by optimized tree searches in `physicsnemo`'s `knn` function. On CPU it calls `scipy`'s KDTree, while on GPU it dispatches to the RAPIDS (cuML) neighbor utilities.
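
For intuition, here is roughly what a tree-based kNN query looks like on CPU using `scipy.spatial.KDTree` directly. This is a standalone sketch, not `physicsnemo`'s implementation; the random array stands in for the `coords` field.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
coords = rng.random((1000, 3))  # stand-in for a [num_points, 3] point cloud

# Build the tree once, then query every point for its 8 nearest neighbors.
tree = KDTree(coords)
distances, indices = tree.query(coords, k=8)  # both have shape [1000, 8]

print(indices.shape)  # (1000, 8)
# Querying a point set against itself: each point's closest neighbor is itself.
print(np.array_equal(indices[:, 0], np.arange(len(coords))))  # True
```

Building the tree costs roughly O(N log N), and each query only visits nearby branches, which is why this scales to clouds with tens of thousands of points.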

> NOTE: Install `scipy` for CPU or `cuml` for GPU. Without the relevant library, `knn` falls back to a brute-force implementation, which can run out of memory on point clouds of this size!
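
The memory cost of brute force is easy to estimate: a dense float32 distance matrix for the 65,795-point sample loaded below needs 65,795 × 65,795 × 4 bytes, or about 17.3 GB. The sketch below is a generic brute-force kNN in PyTorch, shown only to make that cost concrete; it is not the fallback that `physicsnemo` ships, and `brute_force_knn` and `dense_matrix_bytes` are hypothetical helper names.

```python
import torch


def brute_force_knn(points: torch.Tensor, queries: torch.Tensor, k: int):
    """Return (indices, distances) of the k nearest points for each query."""
    # Dense [num_queries, num_points] distance matrix: O(N^2) memory.
    # The exact (non-matmul) mode keeps self-distances exactly zero.
    dists = torch.cdist(
        queries, points, compute_mode="donot_use_mm_for_euclid_dist"
    )
    # k smallest distances per row, returned in ascending order
    distances, indices = torch.topk(dists, k, dim=1, largest=False)
    return indices, distances


def dense_matrix_bytes(num_points: int, bytes_per_elem: int = 4) -> int:
    """Memory needed for one dense float32 distance matrix."""
    return num_points * num_points * bytes_per_elem


torch.manual_seed(0)
coords = torch.rand(200, 3)
indices, distances = brute_force_knn(coords, coords, k=8)
print(indices.shape)  # torch.Size([200, 8])
print(f"{dense_matrix_bytes(65795) / 1e9:.1f} GB")  # 17.3 GB
```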

In [2]:
import scipy  # Fails early if scipy (the CPU kNN backend) is not installed

In [3]:
pointcloud_data_path = "./output/pointcloud_data/"

reader = ZarrReader(path=pointcloud_data_path, group_pattern="*.zarr")
data, metadata = reader[0]

print(f"Loaded sample with {data['coords'].shape[0]} points")
print(f"Fields: {list(data.keys())}")
print()

# Create and apply the KNN edge transform
knn_transform = KNearestNeighbors(
    points_key="coords",
    queries_key="coords",  # Query the point cloud against itself
    k=8,  # Find the 8 nearest neighbors of each point
    extract_keys=["features"],  # Also gather each neighbor's features
    # Querying a point set against itself makes every point its own closest
    # neighbor. Set this to True to drop that self-match from the results.
    drop_first_neighbor=False,
)
print(f"Transform: {knn_transform}")
print()

data_with_edges = knn_transform(data)

# The transform stores the [num_points, k] neighbor indices under this key
edge_key = "neighbors_indices"

print("After transform:")
print(f"  Fields: {list(data_with_edges.keys())}")
print(f"  edge_index shape: {data_with_edges[edge_key].shape}")
print()

# Verify graph structure
n_nodes = data_with_edges["coords"].shape[0]
edges_per_node = data_with_edges[edge_key].shape[1]  # equals k

print("Graph structure:")
print(f"  Nodes: {n_nodes}")
print(f"  Edges / node: {edges_per_node}")
print()

Loaded sample with 65795 points
Fields: ['features', 'coords']

Transform: KNearestNeighbors(points_key=coords, queries_key=coords, k=8)

After transform:
  Fields: ['features', 'coords', 'neighbors_indices', 'neighbors_distances', 'neighbors_coords', 'neighbors_features']
  edge_index shape: torch.Size([65795, 8])

Graph structure:
  Nodes: 65795
  Edges / node: 8



Note that the `KNearestNeighbors` transform does more than generate neighbor indices: it also computes the distance to each neighbor and gathers the neighbors' coordinates. Any additional tensor keys passed in `extract_keys` are gathered the same way:

In [4]:
print(data_with_edges["neighbors_features"].shape)

torch.Size([65795, 8, 8])
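
Each `neighbors_*` field behaves like a gather: the per-point tensor indexed with the `[num_points, k]` neighbor indices, which turns shape `[N, F]` into `[N, k, F]`. The sketch below shows that operation in plain PyTorch, assuming the transform gathers this way (consistent with the shapes above, though its internals may differ). The random neighbor indices are stand-ins, not a real kNN result. The last step drops the self-match by hand, which is presumably the effect of `drop_first_neighbor=True`.

```python
import torch

torch.manual_seed(0)
num_points, k, num_features = 100, 8, 8
features = torch.rand(num_points, num_features)

# Stand-in neighbor indices with the same [num_points, k] layout as
# "neighbors_indices"; column 0 is the point itself, as in a self-query.
others = torch.randint(0, num_points, (num_points, k - 1))
neighbor_idx = torch.cat([torch.arange(num_points).unsqueeze(1), others], dim=1)

# Advanced indexing gathers one feature row per (point, neighbor) pair:
# [num_points, num_features] indexed by [num_points, k] -> [num_points, k, num_features]
neighbors_features = features[neighbor_idx]
print(neighbors_features.shape)  # torch.Size([100, 8, 8])

# Entry [i, j] is simply the feature row of point neighbor_idx[i, j]
assert torch.equal(neighbors_features[3, 2], features[neighbor_idx[3, 2]])

# Dropping the self-match leaves k - 1 true neighbors per point
without_self = neighbor_idx[:, 1:]
print(without_self.shape)  # torch.Size([100, 7])
```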


## Section 2: Combining Batches of Data

To use a batch size greater than 1 with this data, we can build a custom collator for the `physicsnemo` `DataLoader`:

In [5]:
class PyGCollator(Collator):
    """
    Collator that batches graphs using PyTorch Geometric's built-in batching.

    This collator converts each sample to a PyG Data object, then uses
    `Batch.from_data_list()` to handle all the complexity of graph batching:
    - Node features are concatenated: (N1 + N2 + ... + Nb, F)
    - Edge indices are automatically offset and concatenated
    - A `batch` tensor tracks which nodes belong to which graph

    Example:
        Graph 0: 100 nodes, edges [[0,1,2], [1,2,0]]
        Graph 1: 150 nodes, edges [[0,1], [1,0]]

        Batched (handled automatically by PyG):
        - nodes: (250, F)
        - edge_index: [[0,1,2,100,101], [1,2,0,101,100]]  # Graph 1 offset by 100
        - batch: [0]*100 + [1]*150
    """

    def __init__(
        self,
        edge_index_key: str = "edge_index",
        collate_metadata: bool = False,
    ) -> None:
        """
        Initialize the PyG-style collator.

        Args:
            edge_index_key: Key for edge indices in the input data.
                Expected shape is [num_nodes, k] from KNN, which will be
                converted to PyG's [2, num_edges] format.
            collate_metadata: If True, return (batch, metadata_list);
                otherwise return only the PyG Batch.
        """
        self.collate_metadata = collate_metadata
        self.edge_index_key = edge_index_key

    @staticmethod
    def knn_to_edge_index(knn_indices: torch.Tensor) -> torch.Tensor:
        """
        Convert KNN indices to PyG edge_index format.

        Args:
            knn_indices: Tensor of shape [num_nodes, k] where each row contains
                the k nearest neighbor indices for that node.

        Returns:
            edge_index: Tensor of shape [2, num_nodes * k] in PyG COO format,
                where edge_index[0] is source nodes and edge_index[1] is target nodes.
        """
        num_nodes, k = knn_indices.shape
        # Source nodes: each node index repeated k times
        source = torch.arange(num_nodes, device=knn_indices.device).repeat_interleave(k)
        # Target nodes: flatten the KNN indices
        target = knn_indices.reshape(-1)
        return torch.stack([source, target], dim=0)

    def __call__(
        self, samples: Sequence[tuple[dict, dict[str, Any]]]
    ) -> PyGBatch | tuple[PyGBatch, list[dict[str, Any]]]:
        """
        Collate graphs into a batched PyG Batch object.

        Args:
            samples: Sequence of (TensorDict/dict, metadata) tuples.

        Returns:
            The PyG Batch, or a tuple of (PyG Batch, list of metadata dicts)
            when collate_metadata is True.
        """
        if not samples:
            raise ValueError("Cannot collate empty sequence of samples")

        # Separate data and metadata
        data_list = [data for data, _ in samples]

        # Convert each sample to a PyG Data object
        pyg_data_list = []
        for data in data_list:
            # Build kwargs for PyG Data, renaming edge_index_key to 'edge_index'
            data_kwargs = {}
            for key in data.keys():
                tensor = data[key]
                if key == self.edge_index_key:
                    # Convert from KNN format [num_nodes, k] to PyG format [2, num_edges]
                    data_kwargs["edge_index"] = self.knn_to_edge_index(tensor)
                else:
                    data_kwargs[key] = tensor

            pyg_data_list.append(PyGData(**data_kwargs))

        # Use PyG's built-in batching - handles edge index offsetting automatically
        batched_data = PyGBatch.from_data_list(pyg_data_list)

        if self.collate_metadata:
            metadata_list = [meta for _, meta in samples]
            return batched_data, list(metadata_list)
        else:
            return batched_data

    def __repr__(self) -> str:
        return (
            f"PyGCollator(edge_index_key={self.edge_index_key}, "
            f"collate_metadata={self.collate_metadata})"
        )
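
To make the batching concrete, here is a plain-PyTorch sketch of the bookkeeping `Batch.from_data_list()` performs on `edge_index`, `batch`, and `ptr`, using the two graphs from the docstring example. It illustrates the offsets only and is not PyG's implementation; `batch_edge_indices` is a hypothetical helper.

```python
import torch


def batch_edge_indices(edge_indices, num_nodes):
    """Offset each graph's [2, E_i] edge_index by the nodes that precede it."""
    counts = torch.tensor(num_nodes)
    # ptr[i]:ptr[i+1] is the node range of graph i in the batched tensors
    ptr = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(counts, dim=0)])
    offset_edges = [ei + ptr[i] for i, ei in enumerate(edge_indices)]
    edge_index = torch.cat(offset_edges, dim=1)
    # batch[n] is the id of the graph that node n belongs to
    batch = torch.repeat_interleave(torch.arange(len(num_nodes)), counts)
    return edge_index, batch, ptr


# The two graphs from the PyGCollator docstring
g0 = torch.tensor([[0, 1, 2], [1, 2, 0]])
g1 = torch.tensor([[0, 1], [1, 0]])
edge_index, batch, ptr = batch_edge_indices([g0, g1], [100, 150])
print(edge_index.tolist())  # [[0, 1, 2, 100, 101], [1, 2, 0, 101, 100]]
print(ptr.tolist())  # [0, 100, 250]
print(batch.shape)  # torch.Size([250])
```

In PyG, this increment is applied by `Data.__inc__` to attributes whose name contains `index`, such as `edge_index`. The key `neighbors_indices` does not match that rule, which is another reason to convert and rename it to `edge_index` inside the collator.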

Next, wrap the reader and the kNN transform in a `Dataset`:


In [6]:
dataset = Dataset(
    reader=reader,
    transforms=knn_transform,
)

In [7]:
# Inspect the first sample: a (TensorDict, metadata) tuple
dataset[0]

(TensorDict(
     fields={
         coords: Tensor(shape=torch.Size([65795, 3]), device=cpu, dtype=torch.float32, is_shared=False),
         features: Tensor(shape=torch.Size([65795, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_coords: Tensor(shape=torch.Size([65795, 8, 3]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_distances: Tensor(shape=torch.Size([65795, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_features: Tensor(shape=torch.Size([65795, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_indices: Tensor(shape=torch.Size([65795, 8]), device=cpu, dtype=torch.int64, is_shared=False)},
     batch_size=torch.Size([]),
     device=cpu,
     is_shared=False),
 {'source_file': '/Users/coreya/physicsnemo/examples/minimal/datapipes/output/pointcloud_data/sample_000000.zarr',
  'source_filename': 'sample_000000.zarr',
  'index': 0})

In [8]:
dataset[1]

(TensorDict(
     fields={
         coords: Tensor(shape=torch.Size([50860, 3]), device=cpu, dtype=torch.float32, is_shared=False),
         features: Tensor(shape=torch.Size([50860, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_coords: Tensor(shape=torch.Size([50860, 8, 3]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_distances: Tensor(shape=torch.Size([50860, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_features: Tensor(shape=torch.Size([50860, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
         neighbors_indices: Tensor(shape=torch.Size([50860, 8]), device=cpu, dtype=torch.int64, is_shared=False)},
     batch_size=torch.Size([]),
     device=cpu,
     is_shared=False),
 {'source_file': '/Users/coreya/physicsnemo/examples/minimal/datapipes/output/pointcloud_data/sample_000001.zarr',
  'source_filename': 'sample_000001.zarr',
  'index': 1})

Collate the first two samples into a single batched graph:

In [9]:
collator = PyGCollator(edge_index_key="neighbors_indices")

batch_gnn_inputs = collator([dataset[0], dataset[1]])



In [10]:
batch_gnn_inputs

DataBatch(edge_index=[2, 933240], features=[116655, 8], coords=[116655, 3], neighbors_distances=[116655, 8], neighbors_coords=[116655, 8, 3], neighbors_features=[116655, 8, 8], batch=[116655], ptr=[3])
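
Downstream, a GNN uses the `batch` vector to keep graphs separate. In the output above, `ptr=[3]` means `ptr` holds three node boundaries, `[0, 65795, 116655]`, so graph `i` occupies rows `ptr[i]:ptr[i+1]`. A common use is per-graph mean pooling; PyG provides `torch_geometric.nn.global_mean_pool` for this, but the sketch below does it in plain PyTorch so the indexing is visible. `mean_pool` is a hypothetical helper.

```python
import torch


def mean_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average node features per graph using the batch assignment vector."""
    sums = torch.zeros(num_graphs, x.shape[1], dtype=x.dtype).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)


# Two tiny graphs: 3 nodes in graph 0, 2 nodes in graph 1
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 0.0]])
batch = torch.tensor([0, 0, 0, 1, 1])
pooled = mean_pool(x, batch, num_graphs=2)
print(pooled.tolist())  # [[3.0, 4.0], [15.0, 0.0]]

# ptr gives the same grouping as contiguous slices: graph i is x[ptr[i]:ptr[i+1]]
ptr = torch.tensor([0, 3, 5])
graph_1 = x[int(ptr[1]) : int(ptr[2])]
print(graph_1.shape)  # torch.Size([2, 2])
```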

In practice, pass the collator to the `DataLoader` so that batching happens automatically:

In [11]:
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator,
    collate_metadata=False,
)

In [12]:
for idx, batch_gnn in enumerate(dataloader):
    print(f"{idx}: {batch_gnn}")

0: DataBatch(edge_index=[2, 2543888], features=[317986, 8], coords=[317986, 3], neighbors_distances=[317986, 8], neighbors_coords=[317986, 8, 3], neighbors_features=[317986, 8, 8], batch=[317986], ptr=[5])
1: DataBatch(edge_index=[2, 2685848], features=[335731, 8], coords=[335731, 3], neighbors_distances=[335731, 8], neighbors_coords=[335731, 8, 3], neighbors_features=[335731, 8, 8], batch=[335731], ptr=[5])
2: DataBatch(edge_index=[2, 2306776], features=[288347, 8], coords=[288347, 3], neighbors_distances=[288347, 8], neighbors_coords=[288347, 8, 3], neighbors_features=[288347, 8, 8], batch=[288347], ptr=[5])
3: DataBatch(edge_index=[2, 1951112], features=[243889, 8], coords=[243889, 3], neighbors_distances=[243889, 8], neighbors_coords=[243889, 8, 3], neighbors_features=[243889, 8, 8], batch=[243889], ptr=[5])
4: DataBatch(edge_index=[2, 2364192], features=[295524, 8], coords=[295524, 3], neighbors_distances=[295524, 8], neighbors_coords=[295524, 8, 3], neighbors_features=[295524, 8,

## Putting It All Together

PhysicsNeMo's datapipes abstraction, as we've seen in these first three tutorials, offers flexibility at multiple levels. You can independently configure how to _read_ data, how to manipulate individual samples, and how to merge samples into a batch. You can also extend each of these with custom implementations: readers have only a few methods to override, transforms need to implement just one function, and, as this tutorial showed, custom collation logic can produce PyG graph batches.

This flexibility comes at the cost of more verbose configuration: what a one-off research script can do in a few lines of Python takes more setup in `physicsnemo`, though in return each component of your datapipe can be checked and tested. There is also another way to deploy a datapipe: configure it entirely in `hydra` and instantiate the objects with a single line of Python. We'll see that next.