Here is a simple example of the `gtr_rewiring` module.

`gtr_rewiring` provides a set of transforms that add edges to a `torch_geometric` graph. We recommend using `PrecomputeGTREdges` as the dataset's `pre_transform` and `AddPrecomputedGTREdges` as its `transform`. `PrecomputeGTREdges` runs the GTR algorithm on each graph in the dataset and stores the resulting edges, but it does not add them to the graph. `AddPrecomputedGTREdges` then adds the best subset of the precomputed edges to the graph, with the subset size set by its `num_edges` argument. Because a `pre_transform` runs once when the dataset is processed and its output is cached on disk, while a `transform` runs every time a graph is accessed, the expensive GTR computation only happens once.
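To make the split between the two hooks concrete, here is a toy, pure-Python sketch of when each one runs. `ToyDataset`, `expensive_precompute`, and `add_some` are hypothetical stand-ins, not part of `torch_geometric` or `gtr_rewiring`. The sketch only mirrors the behavior that a `pre_transform` runs once per graph when the dataset is built, and a `transform` runs on every access.

```python
# Toy stand-in for a torch_geometric dataset, used only to illustrate when
# `pre_transform` and `transform` run. NOT the torch_geometric or gtr code.
class ToyDataset:
    def __init__(self, graphs, transform=None, pre_transform=None):
        self.transform = transform
        # pre_transform runs once per graph, when the dataset is built
        # (torch_geometric additionally caches the result on disk).
        self._graphs = [pre_transform(g) if pre_transform else g for g in graphs]

    def __len__(self):
        return len(self._graphs)

    def __getitem__(self, idx):
        # transform runs on every access, on a shallow copy, so the stored
        # graph keeps all of its precomputed candidate edges.
        g = dict(self._graphs[idx])
        return self.transform(g) if self.transform else g


calls = {"pre": 0, "transform": 0}

def expensive_precompute(g):
    # Hypothetical: pretend these candidates came from an expensive search.
    calls["pre"] += 1
    return {**g, "candidates": [(0, 2), (1, 3), (0, 3)]}

def add_some(g, k=2):
    # Hypothetical cheap step: append the first k stored candidates.
    calls["transform"] += 1
    return {**g, "edges": g["edges"] + g["candidates"][:k]}

ds = ToyDataset(
    [{"edges": [(0, 1), (1, 2), (2, 3)]}],
    transform=add_some,
    pre_transform=expensive_precompute,
)
first = ds[0]
second = ds[0]
print(calls)           # {'pre': 1, 'transform': 2}
print(first["edges"])  # [(0, 1), (1, 2), (2, 3), (0, 2), (1, 3)]
```

One practical consequence, assuming `AddPrecomputedGTREdges` only selects from the stored edges: the `num_edges` passed to it can presumably be changed, up to the number precomputed, without re-running the GTR search.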

In [1]:
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from gtr import PrecomputeGTREdges, AddPrecomputedGTREdges 

In [2]:
# precompute 30 edges with the gtr algorithm
pre_transform = T.Compose([PrecomputeGTREdges(num_edges=30)])
# add 20 of the precomputed edges to the graph
transform = T.Compose([AddPrecomputedGTREdges(num_edges=20)])
# load the dataset
dataset = TUDataset(
    root="/tmp/",
    name="MUTAG",
    transform=transform,
    pre_transform=pre_transform
)

The precomputed edges are stored in the `precomputed_gtr_edges` attribute of each graph. We can check that the correct number of edges has been precomputed:

In [3]:
# Check that 60 edges have been precomputed for each graph.
# (Each of the 30 edges is stored in both directions, which is
# why we check for 60 columns, not 30.)
if all(
    hasattr(data, "precomputed_gtr_edges")
    and data.precomputed_gtr_edges.shape[1] == 60
    for data in dataset
):
    print("Edges successfully precomputed!")

Edges successfully precomputed!
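The factor of two comes from how `torch_geometric` represents undirected graphs: `edge_index` is a `[2, num_edges]` tensor of directed (source, target) pairs, so one undirected edge occupies two columns. Assuming `precomputed_gtr_edges` uses the same layout (which is consistent with the 60-column check), the arithmetic can be sketched without torch. The helper `to_directed_columns` and the edge list are made up for illustration.

```python
# Minimal sketch of a COO edge index like torch_geometric's `edge_index`
# (shape [2, num_directed_edges]), built from plain Python lists.
def to_directed_columns(undirected_edges):
    """Return [sources, targets] with each undirected edge (u, v) stored
    twice: once as u -> v and once as v -> u."""
    sources, targets = [], []
    for u, v in undirected_edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]

# Hypothetical: 30 made-up undirected edges.
edges = [(i, i + 1) for i in range(30)]
edge_index = to_directed_columns(edges)
num_columns = len(edge_index[0])  # analogous to edge_index.shape[1]
print(num_columns)  # 60
```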


We can now verify that the edges have actually been added to `edge_index`, by comparing each graph against the same graph loaded without the `transform`.

In [4]:
# Load the dataset again without the transform,
# so no precomputed edges are added to the graphs
dataset_wo_edges = TUDataset(
    root="/tmp/",
    name="MUTAG",
    pre_transform=pre_transform
)
# Check that 40 edges have been added to each graph in the dataset
# (20 edges, each added in both directions)
if all(
    data.edge_index.shape[1] - data_wo_edges.edge_index.shape[1] == 40
    for data, data_wo_edges in zip(dataset, dataset_wo_edges)
):
    print("Edges successfully added!")

Edges successfully added!
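The `if all(...)` checks above print nothing when a graph fails, which makes failures easy to miss. Below is a minimal sketch of a check that reports the offending graph indices instead. It compares per-graph edge counts, which for the datasets above would be `data.edge_index.shape[1]`; `expected_added=2 * 20` reflects 20 edges added in both directions. The helper name `find_mismatches` and the sample counts are hypothetical.

```python
def find_mismatches(counts_with, counts_without, expected_added):
    """Return indices of graphs whose edge count did not grow by expected_added."""
    if len(counts_with) != len(counts_without):
        raise ValueError("datasets must contain the same number of graphs")
    return [
        i
        for i, (with_edges, without_edges) in enumerate(zip(counts_with, counts_without))
        if with_edges - without_edges != expected_added
    ]

# With the datasets above, the counts could be gathered as:
#   counts_with = [d.edge_index.shape[1] for d in dataset]
#   counts_without = [d.edge_index.shape[1] for d in dataset_wo_edges]
# Made-up counts are used here so the sketch runs without torch_geometric.
counts_with = [78, 84, 68]
counts_without = [38, 44, 30]
bad = find_mismatches(counts_with, counts_without, expected_added=2 * 20)
print(bad)  # [2]
```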
