## PyTorch Geometric (PyG) tests

In [None]:
import torch

## Datasets
Basic graph data definitions in PyG.

### Data points
Example of a graph implemented in PyG as a data point (`Data` class), i.e. an element of a graph dataset.

In [None]:
from torch_geometric.data import Data

# edge_index stores the graph connectivity as a tensor of shape [2, num_edges]:
# the first row holds the starting node index of each edge, the second row the ending node index.
# In the following example edge_index encodes a graph with three nodes (0, 1, 2)
# and four directed edges (0->1, 1->0, 1->2, 2->1).
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# The x argument of Data encodes the node features. In the following example x assigns a
# 1-dimensional feature vector to each node (x0 = -1, x1 = 0, x2 = 1).
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# Given edge_index and the feature matrix x, we can construct a PyG data point (i.e.
# a graph) as follows.
data = Data(x=x, edge_index=edge_index)
# NOTE: the default string representation shows only the shapes of x and edge_index
print(data)
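
The `[2, num_edges]` layout described above can be less natural to write by hand than a list of (start, end) pairs. The following is a minimal sketch, using only `torch`, of how such a pair list might be converted into the same `edge_index` tensor. The names `edge_pairs`, `pairs` and `is_symmetric` are introduced here for illustration and are not part of the original notebook.

```python
import torch

# Edges written as (start, end) pairs, one pair per row: shape [num_edges, 2].
edge_pairs = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)

# Transpose to the [2, num_edges] layout and make the result contiguous in memory.
edge_index = edge_pairs.t().contiguous()

print(edge_index.shape)     # torch.Size([2, 4])
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]

# Illustrative check: every edge also appears in the reverse direction,
# so the four directed edges describe an undirected path 0 - 1 - 2.
pairs = {tuple(edge) for edge in edge_pairs.tolist()}
is_symmetric = all((end, start) in pairs for start, end in pairs)
print(is_symmetric)         # True
```

The resulting tensor is equal to the `edge_index` defined by hand in the cell above, so it could be passed to `Data` in the same way.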

### Implemented Datasets
How to import one of the graph datasets already implemented in PyG, using the `Dataset` class. 

In [9]:
from torch_geometric.datasets import BAShapes

# The library constructor returns an instance of the chosen dataset.
# In this example, BAShapes() returns a Barabasi-Albert (BA) graph augmented with motifs
# (a 300-node base graph with 80 "house"-structured motifs attached to it), generated following
# the "GNNExplainer: Generating Explanations for Graph Neural Networks" paper.
dataset = BAShapes(connection_distribution="random")
print(f"[dataset]> ...loading dataset '{dataset}' from PyG")

# a Dataset object exposes some attributes about the data
print("\t#entries:      ", len(dataset))
print("\t#classes:      ", dataset.num_classes)
print("\t#node_features:", dataset.num_node_features)
print("\t#edge_features:", dataset.num_edge_features)

# a dataset entry (i.e. a graph) is retrieved as a Data object (i.e. a data point)
graph = dataset[0]
print(f"\n[dataset]> {dataset} dataset graph...")
print("\t->", graph)
print("\t#nodes:", graph.num_nodes)
print("\t#edges:", graph.num_edges)

# shapes of the node feature matrix and of the edge index
print(graph.x.size())
print(graph.edge_index.size())

[dataset]> ...loading dataset 'BAShapes()' from PyG
	#entries:       1
	#classes:       4
	#node_features: 10
	#edge_features: 0

[dataset]> BAShapes() dataset graph...
	-> Data(x=[700, 10], edge_index=[2, 3922], y=[700], expl_mask=[700], edge_label=[3922])
	#nodes: 700
	#edges: 3922
torch.Size([700, 10])
torch.Size([2, 3922])
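
As a quick sanity check on the output above, the reported node count can be reproduced from the dataset description. The sketch below is plain arithmetic, not a call into PyG. It assumes the base graph and motif counts quoted in the code comments, and that each "house" motif has five nodes (four for the walls, one for the roof). All variable names are illustrative.

```python
# Assumed values, taken from the dataset description in the comments above.
num_ba_nodes = 300     # nodes in the Barabasi-Albert base graph
num_motifs = 80        # number of "house" motifs attached to the base graph
nodes_per_house = 5    # assumption: a house motif has 4 wall nodes + 1 roof node

expected_num_nodes = num_ba_nodes + num_motifs * nodes_per_house
print(expected_num_nodes)  # 700
```

This matches the `#nodes: 700` line and the first dimension of `x=[700, 10]` printed for the graph.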
