## **Advanced Mini-Batching**
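- In the image or language domain, mini-batching is typically achieved by rescaling or padding each example into a set of equally-sized shapes, so that examples can be grouped along an additional batch dimension.
- Since graphs are among the most general data structures and can hold any number of nodes or edges, these approaches are either infeasible or lead to a lot of unnecessary memory consumption.
- PyTorch Geometric instead stacks adjacency matrices in a diagonal fashion, creating a giant graph that holds multiple isolated subgraphs, and concatenates node and target features in the node dimension. GNN operators need no modification, because no messages are exchanged between the disconnected subgraphs.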
- PyTorch Geometric automatically takes care of batching multiple graphs into a single giant graph with the help of the **torch_geometric.data.DataLoader** class (moved to **torch_geometric.loader.DataLoader** in PyG 2.0).
    + Internally, it is just a regular PyTorch DataLoader that overrides its collate() functionality, i.e., the definition of how a list of examples should be grouped together.
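The diagonal stacking can be sketched in plain PyTorch without PyG. This is a minimal illustration under simplifying assumptions (each graph is only a node-feature matrix `x` plus an `edge_index`; `collate_graphs` and `dense_adj` are hypothetical helpers, not PyG functions), not the actual collate() implementation of the DataLoader.

```python
import torch

def collate_graphs(graphs):
    """Hypothetical helper: merge (x, edge_index) pairs into one disconnected graph."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for i, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node indices so they point into the concatenated node-feature matrix.
        edge_indices.append(edge_index + offset)
        # Record which graph each node came from (PyG calls this the batch vector).
        batch.append(torch.full((x.size(0),), i, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(edge_indices, dim=1), torch.cat(batch)

def dense_adj(edge_index, num_nodes):
    """Hypothetical helper: dense adjacency matrix from an edge list."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

g1 = (torch.randn(3, 8), torch.tensor([[0, 1], [1, 2]]))  # 3 nodes, 2 edges
g2 = (torch.randn(2, 8), torch.tensor([[0], [1]]))        # 2 nodes, 1 edge
x, edge_index, batch = collate_graphs([g1, g2])

print(x.shape)              # torch.Size([5, 8])
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]

# The merged adjacency matrix is block-diagonal: no edges between the two subgraphs.
A = dense_adj(edge_index, x.size(0))
print(torch.equal(A, torch.block_diag(dense_adj(g1[1], 3), dense_adj(g2[1], 2))))  # True
```

The only per-graph bookkeeping is the running node offset; everything else is plain concatenation, which is why no padding is needed.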

### **Pairs of Graphs**
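- In case you want to store multiple graphs in a single torch_geometric.data.Data object, e.g., for applications such as graph matching, you need to ensure correct batching behaviour across all those graphs.
- For example, consider storing two graphs, a source graph $G_s$ and a target graph $G_t$, in a single torch_geometric.data.Data object.
- In this case, edge_index_s must be incremented by the number of nodes in $G_s$ (x_s.size(0)) and edge_index_t by the number of nodes in $G_t$ (x_t.size(0)); this is specified by overriding the __inc__() method, e.g.: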

In [3]:
import torch
from torch_geometric.data import Data
try:
    from torch_geometric.loader import DataLoader  # PyG >= 2.0
except ImportError:
    from torch_geometric.data import DataLoader    # older PyG releases

class PairData(Data):
    def __init__(self, edge_index_s=None, x_s=None, edge_index_t=None, x_t=None):
        super().__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t

    def __inc__(self, key, value, *args, **kwargs):
        # Shift each graph's edge indices by that graph's own number of nodes.
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [7]:
# Example Data
edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.
print("- Source graph edge list size : {}".format(edge_index_s.shape))
print("- Source graph node attr size : {}".format(x_s.shape))
print("- Target graph edge list size : {}".format(edge_index_t.shape))
print("- Target graph node attr size : {}".format(x_t.shape))

- Source graph edge list size : torch.Size([2, 4])
- Source graph node attr size : torch.Size([5, 16])
- Target graph edge list size : torch.Size([2, 3])
- Target graph node attr size : torch.Size([4, 16])


In [8]:
data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
print("- Paired Data : {}".format(data))

- Paired Data : PairData(edge_index_s=[2, 4], edge_index_t=[2, 3], x_s=[5, 16], x_t=[4, 16])


In [9]:
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

In [12]:
print(batch)
print(batch.edge_index_s)
print(batch.edge_index_t)

Batch(edge_index_s=[2, 8], edge_index_t=[2, 6], x_s=[10, 16], x_t=[8, 16])
tensor([[0, 0, 0, 0, 5, 5, 5, 5],
        [1, 2, 3, 4, 6, 7, 8, 9]])
tensor([[0, 0, 0, 4, 4, 4],
        [1, 2, 3, 5, 6, 7]])
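To see what the overridden __inc__() does, the batched edge indices above can be reproduced by hand. This is a minimal sketch in plain PyTorch, not PyG's actual collate code; `concat_with_increments` is a hypothetical helper that shifts each graph's `edge_index` by the running sum of the increments of the graphs before it.

```python
import torch

# Same edge lists as in the example above.
edge_index_s = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]])
edge_index_t = torch.tensor([[0, 0, 0], [1, 2, 3]])

def concat_with_increments(edge_indices, increments):
    """Hypothetical helper: shift each edge_index by the running sum of earlier increments."""
    out, offset = [], 0
    for edge_index, inc in zip(edge_indices, increments):
        out.append(edge_index + offset)
        offset += inc
    return torch.cat(out, dim=1)

# __inc__ returns x_s.size(0) = 5 for 'edge_index_s' and x_t.size(0) = 4 for 'edge_index_t'.
batched_s = concat_with_increments([edge_index_s, edge_index_s], [5, 5])
batched_t = concat_with_increments([edge_index_t, edge_index_t], [4, 4])
print(batched_s.tolist())  # [[0, 0, 0, 0, 5, 5, 5, 5], [1, 2, 3, 4, 6, 7, 8, 9]]
print(batched_t.tolist())  # [[0, 0, 0, 4, 4, 4], [1, 2, 3, 5, 6, 7]]
```

Each key gets its own running offset, so the source and target graphs are batched independently even though they live in the same Data object.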


### **Bipartite Graphs**
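- The adjacency matrix of a bipartite graph defines the relationship between nodes of two different node types.
- In general, the number of nodes of each node type does not need to match, resulting in a non-square adjacency matrix $A \in \{0, 1\}^{N \times M}$ with potentially $N \neq M$.
- When batching bipartite graphs, the source node indices in edge_index (row 0) must therefore be incremented by the number of source nodes, x_s.size(0), and the target node indices (row 1) by the number of target nodes, x_t.size(0). This is why __inc__() below returns a $2 \times 1$ tensor instead of a single number.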
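A small sketch of why the adjacency matrix is non-square: with 2 source and 3 target nodes (the sizes used in the BipartiteData example of this section), the dense adjacency matrix has shape 2 × 3, and row and column indices refer to different node sets. This uses plain PyTorch only; `num_src` and `num_dst` are illustrative names.

```python
import torch

# Edges from source nodes (row 0) to target nodes (row 1), same edges as the BipartiteData example.
edge_index = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]])
num_src, num_dst = 2, 3  # N source nodes, M target nodes

adj = torch.zeros(num_src, num_dst, dtype=torch.long)
adj[edge_index[0], edge_index[1]] = 1
print(adj.shape)     # torch.Size([2, 3])
print(adj.tolist())  # [[1, 1, 0], [0, 1, 1]]
```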

In [14]:
class BipartiteData(Data):
    def __init__(self, edge_index=None, x_s=None, x_t=None):
        super().__init__()
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t

    def __inc__(self, key, value, *args, **kwargs):
        # Row 0 (source indices) grows by x_s.size(0), row 1 (target indices) by x_t.size(0).
        if key == 'edge_index':
            return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
        return super().__inc__(key, value, *args, **kwargs)

In [18]:
edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
print("- Source graph nodes : {}".format(x_s.shape))
print("- Target graph nodes : {}".format(x_t.shape))
print("- Bipartite graph data : {}".format(data))

- Source graph nodes : torch.Size([2, 16])
- Target graph nodes : torch.Size([3, 16])
- Bipartite graph data : BipartiteData(edge_index=[2, 4], x_s=[2, 16], x_t=[3, 16])


In [23]:
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print("- Bipartite graph batch example : {}".format(batch))

- Bipartite graph batch example : Batch(batch=[6], edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16])


In [24]:
print(batch.edge_index)

tensor([[0, 0, 1, 1, 2, 2, 3, 3],
        [0, 1, 1, 2, 3, 4, 4, 5]])
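The batched edge_index above can be reproduced in plain PyTorch to show how the 2 × 1 increment from BipartiteData.__inc__() is applied: broadcasting adds the first entry to the source row and the second entry to the target row. A minimal sketch; `concat_with_increments` is a hypothetical helper that mimics the running offset, not PyG's collate code.

```python
import torch

edge_index = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]])
# Column vector: row 0 shifts by x_s.size(0) = 2, row 1 shifts by x_t.size(0) = 3.
inc = torch.tensor([[2], [3]])

def concat_with_increments(edge_indices, increments):
    """Hypothetical helper: shift each edge_index by the running sum of earlier increments."""
    out, offset = [], 0
    for ei, step in zip(edge_indices, increments):
        # Broadcasting a [2, 1] offset adds offset[0] to row 0 and offset[1] to row 1.
        out.append(ei + offset)
        offset = offset + step
    return torch.cat(out, dim=1)

batched = concat_with_increments([edge_index, edge_index], [inc, inc])
print(batched.tolist())  # [[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 3, 4, 4, 5]]
```

With a single scalar increment such as x_s.size(0) = 2 for both rows, the second copy's edge (0, 0) would become (2, 2), pointing at target node 2 of the first graph; the per-row increment avoids this.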
