# Part I: Data as PyG Dataset

Source: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html

Example of saving a graph as Dataset: https://medium.com/cj-express-tech-tildi/first-timers-guide-to-pytorch-geometric-part-1-the-basic-1b6006e1f4db

A graph can be wrapped in a PyG Dataset so that the raw data is loaded, processed and saved in PyG format automatically; the saved dataset is later loaded and fed to the DataLoader for the neural network. (See the Remarks section: saving as a dataset is not a mandatory step.)

Two abstract classes are available for datasets:
- torch_geometric.data.InMemoryDataset: inherits from torch_geometric.data.Dataset and should be used if the whole dataset **fits into CPU memory (smaller dataset)**
- torch_geometric.data.Dataset: to be used if the data does **not fit into CPU memory (large dataset)**

Depending on the size of the data, implement a class that inherits from either InMemoryDataset or Dataset and overrides a few methods (abstract in the base class).

### InMemoryDataset (small dataset)

The dataset class must inherit from torch_geometric.data.InMemoryDataset and the following methods must be implemented:
- *raw_file_names()*: list of files that must be found in raw_dir for the download to be skipped. These files hold the data in raw format, which still has to be processed and stored in the files listed in processed_file_names.
- *processed_file_names()*: list of files that must be found in processed_dir for the processing to be skipped. These files hold the processed data (ideally ready for the Machine Learning model).
- *download()*: downloads raw data into raw_dir.
- *process()*: processes the raw data and saves it into the processed_dir. This is the **most important function to implement**.

Furthermore, the constructor can receive the following arguments (**None by default**), which are passed on to super().__init__():
- transform: dynamically transforms the data object before accessing (so it is best used for data augmentation)
- pre_transform: applies the transformation before saving the data objects to disk (so it is best used for heavy precomputation which needs to be only done once)
- pre_filter: manually filters out data objects before saving.
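
To make the timing of the three hooks concrete, here is a minimal plain-Python sketch. `ToyDataset` is a hypothetical stand-in, not a PyG class; it only assumes the behaviour described above: `pre_filter` and `pre_transform` run once when the data is stored, while `transform` runs on every access.

```python
class ToyDataset:
    """Hypothetical stand-in (not a PyG class) showing when each hook runs."""

    def __init__(self, raw_items, transform=None, pre_transform=None, pre_filter=None):
        self.transform = transform
        # "Processing" happens once: filter first, then pre-transform, then store.
        items = list(raw_items)
        if pre_filter is not None:
            items = [item for item in items if pre_filter(item)]
        if pre_transform is not None:
            items = [pre_transform(item) for item in items]
        self._stored = items

    def __len__(self):
        return len(self._stored)

    def __getitem__(self, idx):
        item = self._stored[idx]
        # `transform` is applied on every access and never changes the stored item.
        return item if self.transform is None else self.transform(item)


calls = {"pre_transform": 0, "transform": 0}


def square(value):
    calls["pre_transform"] += 1
    return value * value


def negate(value):
    calls["transform"] += 1
    return -value


dataset = ToyDataset(
    raw_items=[1, 2, 3, 4],
    transform=negate,
    pre_transform=square,
    pre_filter=lambda value: value % 2 == 0,  # keep even items only
)

first_access = dataset[0]
second_access = dataset[0]
```

After the two accesses, `calls` shows that `square` ran once per kept item, whereas `negate` ran once per access.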

One unique characteristic of InMemoryDataset is the call to **collate()**: this function transforms a Python list of Data or HeteroData objects <u>to the internal graph data storage format of InMemoryDataset</u>. The result is a single *data* object, which can be used together with *slices* to reconstruct the individual examples.

Saving a huge Python list of objects is time consuming. Instead of saving the list directly, PyTorch Geometric stores the collated data object together with a dictionary of slices, from which the individual examples can be reconstructed:

<code>data, slices = self.collate(data_list)</code>

then saves the dataset with 

<code>torch.save((data, slices), self.processed_paths[0])</code>.

Furthermore, we need to load these two objects in the constructor into the properties self.data and self.slices with the *standard line of code*:
<code>self.data, self.slices = torch.load(self.processed_paths[0])</code>.
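
The idea behind `data` and `slices` can be sketched in plain Python. This is an illustration of the concept only, not PyG's implementation: the helper names `collate` and `get_example` are hypothetical, and a single attribute is assumed, whereas PyG keeps offsets per attribute in a dictionary.

```python
def collate(feature_lists):
    """Concatenate per-graph lists into one flat list plus slice offsets."""
    flat, slices = [], [0]
    for features in feature_lists:
        flat.extend(features)
        slices.append(len(flat))  # end offset of this graph = start of the next
    return flat, slices


def get_example(flat, slices, idx):
    """Reconstruct example `idx` from the flat storage and the offsets."""
    return flat[slices[idx]:slices[idx + 1]]


# Three "graphs" with 2, 1 and 3 node features respectively.
graphs = [[0.1, 0.2], [0.3], [0.4, 0.5, 0.6]]
data, slices = collate(graphs)
last_graph = get_example(data, slices, 2)
```

Storing one flat container plus offsets avoids serialising many small objects, which is the motivation given above for not saving the list directly.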

The module torch_geometric.data also provides helpers for downloading data, such as download_url.


In [1]:
import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnInMemoryDataset(InMemoryDataset):
    # init: here the transform, pre_transform and pre_filter functions can be passed
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        
        # Call the base-class constructor; it runs download() and process() when the raw or processed files are missing
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`; `url` is a placeholder for the location of the raw files.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        # process the data
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # collate and save data
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

### Dataset (large dataset)

If the dataset does not fit into CPU memory, the class **torch_geometric.data.Dataset** shall be used as the **base class**. In addition to <code>raw_file_names()</code>, <code>processed_file_names()</code>, <code>download()</code> and <code>process()</code>, the methods <code>len()</code> (returns the number of examples in the dataset) and <code>get()</code> (loads a single graph, given its index) shall be implemented.

Unlike InMemoryDataset, **collate() is not used**.
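
The one-file-per-example layout behind `len()` and `get()` can be sketched without PyG. `LazyGraphStore` below is a hypothetical stand-in that uses `pickle` instead of `torch.save`; it only illustrates that each access reads a single file rather than the whole dataset.

```python
import os
import pickle
import tempfile


class LazyGraphStore:
    """Hypothetical stand-in for the one-file-per-graph layout (not a PyG class)."""

    def __init__(self, processed_dir):
        self.processed_dir = processed_dir

    def _path(self, idx):
        return os.path.join(self.processed_dir, f"data_{idx}.pkl")

    def process(self, graphs):
        # Write each example to its own file, mirroring the `data_{idx}` naming.
        for idx, graph in enumerate(graphs):
            with open(self._path(idx), "wb") as handle:
                pickle.dump(graph, handle)

    def len(self):
        # Number of stored examples = number of `data_*` files on disk.
        return sum(name.startswith("data_") for name in os.listdir(self.processed_dir))

    def get(self, idx):
        # Only the requested example is read from disk.
        with open(self._path(idx), "rb") as handle:
            return pickle.load(handle)


store = LazyGraphStore(tempfile.mkdtemp())
store.process([{"num_nodes": n} for n in (3, 5, 8)])
third_graph = store.get(2)
```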


In [1]:
import os.path as osp

import torch
from torch_geometric.data import Data, Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`; `url` is a placeholder for the location of the raw files.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

## Remarks on PyG Dataset

#### Usage of PyG's <code>Dataset</code> is not mandatory
We use PyG's Dataset only to process the dataset and save it to a file. If the dataset does not need to be saved to disk, the data can be passed directly to the <code>DataLoader</code>, which interfaces with the Machine Learning model:

<code>
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader
    data_list = [Data(...), ..., Data(...)]
    loader = DataLoader(data_list, batch_size=32)
</code>

#### Execution of download() / process() can be skipped by not implementing these functions

# Part II: GNN

Source: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html 

One of the first GNN formulations uses the **message passing** approach.

Let **x**_*i* be the feature vector of node *i* and **e**_*ij* the (optional) feature vector of the edge connecting nodes *i* and *j*; the node feature update is then:

![image](images/message_passing_general.png)

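- $\phi$: message function, which computes the message sent along each edge from the features of its two end nodes and the (optional) edge features
- *O*: aggregation function, which combines the messages from the neighborhood **N** into one single feature vector (must be **permutation invariant**)
- $\gamma$: update function, which computes the new feature vector of node *i* from its previous features and the aggregated message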

**Permutation invariant**: a function *f* is permutation invariant if its result does not depend on the order of its arguments, i.e. *f*(**x**1, **x**2, **x**3) = *f*(**x**2, **x**3, **x**1) = *f*(**x**3, **x**2, **x**1) = ...
This is important because **graphs do not have an ordering** (first/last node, left/right node, etc.).
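
Permutation invariance can be checked by brute force on a small input. The sketch below uses plain Python; the helper `is_permutation_invariant` and the deliberately order-dependent `first_minus_rest` are illustrative names, not part of PyG.

```python
from itertools import permutations

neighbor_features = [3.0, 1.0, 2.0]


def is_permutation_invariant(func, values):
    """Return True if `func` gives the same result for every ordering of `values`."""
    reference = func(list(values))
    return all(func(list(order)) == reference for order in permutations(values))


def mean(values):
    return sum(values) / len(values)


def first_minus_rest(values):
    # Order-dependent on purpose: the first element is treated differently.
    return values[0] - sum(values[1:])


results = {
    "sum": is_permutation_invariant(sum, neighbor_features),
    "mean": is_permutation_invariant(mean, neighbor_features),
    "max": is_permutation_invariant(max, neighbor_features),
    "first_minus_rest": is_permutation_invariant(first_minus_rest, neighbor_features),
}
```

Sum, mean and max pass the check, which is why they are suitable as aggregation functions; any function that singles out a position in the list does not.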

In pseudocode format, the above equation can be interpreted in a simplified version as:

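$\mathbf{x}_i^{(k)} = \mathrm{UPDATE}^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathrm{AGGREGATE}\left(\mathrm{MESSAGE}^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{ij}\right)\right)\right)$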
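
The pseudocode above can be turned into a small runnable sketch. It is a plain-Python illustration under simplifying assumptions: scalar node features, no edge features, and a hypothetical helper `message_passing_step` whose three function arguments stand for MESSAGE, AGGREGATE and UPDATE.

```python
def message_passing_step(x, edges, message, aggregate, update):
    """One message passing round.

    x maps node -> feature; edges is a list of (j, i) pairs meaning "j sends to i".
    """
    inbox = {node: [] for node in x}
    for j, i in edges:
        inbox[i].append(message(x[i], x[j]))  # MESSAGE for edge (j, i)
    # AGGREGATE the inbox of each node, then UPDATE the node feature.
    return {i: update(x[i], aggregate(inbox[i])) for i in x}


x = {0: 1.0, 1: 2.0, 2: 4.0}
# Undirected path 0 - 1 - 2, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]

new_x = message_passing_step(
    x,
    edges,
    message=lambda x_i, x_j: x_j,                     # send the neighbor feature as-is
    aggregate=sum,                                    # permutation-invariant aggregation
    update=lambda x_i, aggregated: x_i + aggregated,  # add the aggregate to the old feature
)
```

Node 1 has two neighbors, so its new feature combines its own value with the features of nodes 0 and 2; the end nodes each receive a single message.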

#### Message passing in PyG
PyG provides the class <code>MessagePassing</code> for the *message*, *aggregate* and *update* functions:
- <code>MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)</code>: defines the **aggregation function** to use (aggr="add", "mean" or "max") and the flow direction of message passing (either "source_to_target" or "target_to_source"). 
- <code>MessagePassing.propagate(edge_index, size=None, **kwargs)</code>: this is the **initial call to start propagating messages**.
- <code>MessagePassing.message(...)</code>: defines the **message function** $\phi$ to node *i* for each edge (j, i) if flow="source_to_target" and (i,j) if flow="target_to_source". It can take any argument which was initially passed to <code>propagate()</code>.
- <code>MessagePassing.update(...)</code>: this is the **update function** $\gamma$.
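
The effect of `flow` on which row of `edge_index` sends and which one receives can be sketched in plain Python. This is an illustration, not PyG code: `propagate_sum` is a hypothetical helper, and it assumes the convention stated above, where messages travel from `edge_index[0]` to `edge_index[1]` under "source_to_target" and in the opposite direction under "target_to_source".

```python
def propagate_sum(x, edge_index, flow="source_to_target"):
    """Sum-aggregate neighbor features along the chosen flow direction."""
    first_row, second_row = edge_index
    if flow == "source_to_target":
        senders, receivers = first_row, second_row
    elif flow == "target_to_source":
        senders, receivers = second_row, first_row
    else:
        raise ValueError(f"unknown flow: {flow}")

    out = [0.0] * len(x)
    for j, i in zip(senders, receivers):
        out[i] += x[j]  # node i aggregates the message from node j
    return out


x = [1.0, 10.0, 100.0]
edge_index = [[0, 1], [1, 2]]  # two directed edges: 0 -> 1 and 1 -> 2

forward = propagate_sum(x, edge_index)
backward = propagate_sum(x, edge_index, flow="target_to_source")
```

With the default flow, node 0 receives nothing because no edge points at it; with the reversed flow, node 2 is the one left without messages.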

#### Example 1:  GCN layer from Kipf and Welling 

The message passing is defined as:
![image](images/GCN.png)

Let's break down the function (from first to last operation):
1. <u>Message</u> function $\phi$:
    - For each neighbor node *j*, the feature vector **x**_*j* is linearly transformed by the weight matrix **W**, which maps vectors of size in_channels to vectors of size out_channels
    - the transformed feature vector is divided by sqrt(degree(*i*)) * sqrt(degree(*j*))
2. <u>Aggregation</u> function *O*:
    - sum all the transformed and normalized neighbor feature vectors (self-loops make node *i* one of its own neighbors)
3. <u>Update</u> function $\gamma$:
    - add a bias vector **b** of size out_channels
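
The three steps can be traced by hand on a tiny graph. The following is a dependency-free sketch with nested Python lists, not the PyG layer; `gcn_layer` and the toy values for the features, weights and bias are hypothetical, and the weight is assumed to be stored as an `[in_channels, out_channels]` list of lists.

```python
import math


def gcn_layer(x, edges, weight, bias):
    """GCN-style update on lists. edges holds (j, i) pairs meaning "j sends to i"."""
    num_nodes = len(x)
    out_channels = len(bias)

    # Add self-loops so that every node also receives its own features.
    edges = list(edges) + [(i, i) for i in range(num_nodes)]

    # Linear transform: h_j[o] = sum_c x_j[c] * weight[c][o].
    h = [
        [sum(x_j[c] * weight[c][o] for c in range(len(x_j))) for o in range(out_channels)]
        for x_j in x
    ]

    # Degree of each node, counted over incoming edges (self-loops included).
    deg = [0] * num_nodes
    for _, i in edges:
        deg[i] += 1

    # Start from the bias, then add the normalized messages (sum aggregation).
    out = [list(bias) for _ in range(num_nodes)]
    for j, i in edges:
        norm = 1.0 / math.sqrt(deg[i] * deg[j])
        for o in range(out_channels):
            out[i][o] += norm * h[j][o]
    return out


# Path graph 0 - 1 - 2 with 2 input channels and 1 output channel.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
weight = [[1.0], [2.0]]
bias = [0.5]

out = gcn_layer(x, edges, weight, bias)
```

With self-loops the degrees are 2, 3 and 2, so for node 0 the result is 1/2 (its own transformed feature) plus 2/sqrt(6) (the message from node 1) plus the bias 0.5.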


In [1]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j