*Geometric Deep Learning* [has become increasingly popular over the past few years](https://twitter.com/prlz77/status/1178662575900368903), and my curiosity about it has been growing significantly in the past few months.  
In this post, I will try to compress the main intuitions (at a very high level) behind *Graph Neural Networks*, so that it will be easier to get to more detailed aspects in future posts.
<!-- TEASER_END -->

<h3><strong>Table of contents:</strong></h3>

1. [What is a graph?](#what-is-a-graph)
2. [Node, edge and graph level information](#information-level)
3. [Graph applications](#modelling-examples)
4. [Graph Machine Learning: Tasks and Challenges](#GML-tasks)
5. [Computation Graphs, Message Passing and Aggregation](#computation-graphs)
6. [Transductive and inductive setting](#transductive-inductive)
7. [References](#references)

<h2><strong>What is a graph?</strong></h2> <a class="anchor" id="what-is-a-graph"></a>
    
A graph is a data structure made of *nodes (or vertices)* connected by *edges (or links)*; we can simply say that edges represent the relationships between nodes.  
Depending on how the edges link the nodes, we can classify graphs as [*directed*](https://en.wikipedia.org/wiki/Directed_graph) (edge direction matters and is specified for each edge) or [*undirected*](https://www.cpp.edu/~ftang/courses/CS241/notes/graph.htm) (edges have no direction and can be traversed either way).  
<div align="center">
    <img src="/images/undirected_directed_graph.png" alt="Drawing" width="300"/>
</div>

Connectivity offers another way to classify graphs: for example, a graph in which every node is connected to all the others is called [*complete*](https://en.wikipedia.org/wiki/Complete_graph), while a directed graph without any directed cycles is called [*directed acyclic*](https://en.wikipedia.org/wiki/Directed_acyclic_graph).
<div align="center">
    <img src="/images/complete_acyclic_graph.png" alt="Drawing" width=300/>
</div>

There are many other ways to classify graphs, but the important thing to bear in mind is that a graph's connectivity can be described by an [*adjacency matrix*](https://en.wikipedia.org/wiki/Adjacency_matrix#:~:text=In%20graph%20theory%20and%20computer,or%20not%20in%20the%20graph.&text=If%20the%20graph%20is%20undirected,the%20adjacency%20matrix%20is%20symmetric.), which stores a "1" in position (i, j) if nodes i and j are connected (like node "a" and node "c" in the image below) and a "0" if they are not. For an undirected graph, this matrix is symmetric. 
<br>
<div align="center">
    <img src="/images/adjacency_matrix.png" alt="Drawing" width=400/> 
</div>
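To make this concrete, here is a minimal sketch in plain Python that builds an adjacency matrix from a list of edges. The four nodes and their edges are a hypothetical example (they are not meant to reproduce the figure exactly), and `adjacency_matrix` is just an illustrative helper name:

```python
def adjacency_matrix(num_nodes, edges, directed=False):
    """Build a dense adjacency matrix (list of lists) from an edge list.

    Entry [i][j] is 1 if there is an edge from node i to node j, else 0.
    """
    A = [[0] * num_nodes for _ in range(num_nodes)]
    for i, j in edges:
        A[i][j] = 1
        if not directed:
            # undirected edges are stored in both directions
            A[j][i] = 1
    return A


# Hypothetical graph with nodes a=0, b=1, c=2, d=3
nodes = ["a", "b", "c", "d"]
edges = [(0, 2), (0, 1), (1, 3)]
A = adjacency_matrix(len(nodes), edges)
for name, row in zip(nodes, A):
    print(name, row)
```

Because every undirected edge is written in both directions, the resulting matrix is symmetric; passing `directed=True` keeps only the (source, target) entry.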

<h2><strong>Node, edge and graph level information</strong></h2> <a class="anchor" id="information-level"></a>   
<img align='left' src="/images/embedding.png" alt="Drawing" width=400/>

Each part of a graph carries information at a different level: 
- **node**, with attributes such as the number of neighbours or the node class 
- **edge**, with direction and attributes such as edge weight or edge class
- **graph**, with global attributes such as the number of nodes or shortest paths

Furthermore, each of these elements can store information as a vector, giving node, edge and graph-level [embeddings](https://developers.google.com/machine-learning/crash-course/embeddings/video-lecture#:~:text=An%20embedding%20is%20a%20relatively,can%20translate%20high%2Ddimensional%20vectors.&text=Ideally%2C%20an%20embedding%20captures%20some,learned%20and%20reused%20across%20models.).  
This is particularly important in *deep learning*, since embeddings can be learned during training, just like the network's weights.
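One simple way to picture this is a container that keeps a feature vector for each node, each edge and the graph as a whole. The `Graph` class below is a hypothetical sketch (it does not mirror any specific library), and in a real model these vectors would be learned rather than written by hand:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Graph:
    """Hypothetical container for node-, edge- and graph-level information."""

    # One feature vector (embedding) per node, indexed by node id
    node_features: list[list[float]]
    # Edges as (source, target) pairs of node ids
    edges: list[tuple[int, int]]
    # One feature vector per edge, in the same order as `edges`
    edge_features: list[list[float]]
    # A single feature vector describing the whole graph
    global_features: list[float] = field(default_factory=list)

    def __post_init__(self) -> None:
        if len(self.edges) != len(self.edge_features):
            raise ValueError("every edge needs exactly one feature vector")

    @property
    def num_nodes(self) -> int:
        return len(self.node_features)


g = Graph(
    node_features=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    edges=[(0, 1), (1, 2)],
    edge_features=[[0.5], [2.0]],
    global_features=[3.0],
)
print(g.num_nodes, len(g.edges), g.global_features)
```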

<h2><strong>Graph applications</strong></h2> <a class="anchor" id="modelling-examples"></a>

As stated in the magnificent article [*A gentle introduction to Graph Neural Networks*](https://distill.pub/2021/gnn-intro/#graph-to-tensor)[<sup>1</sup>](#fn1) (from which I have taken most of the inspiration for this post), *"graphs are all around us"*, and the following are just a few examples of what they can represent: 
- **interactions** (like [social media networks](https://brilliant.org/wiki/social-networks/), [trade networks](https://wits.worldbank.org/CountryNetwork.aspx?lang=en) or document citation networks. The [CORA](https://relational.fit.cvut.cz/dataset/CORA) dataset is a good example of the latter)
- **fraud detection systems** ([here](https://blog.careem.com/en/crazywall-graph-based-identity-fraud-detection/) a graph-based identity fraud detection system is described)
- **chemical molecular structures** or interactions between proteins (like the [PPI](https://paperswithcode.com/dataset/ppi) dataset)
- **recommender systems** (here's [a review of graph learning based recommender systems](https://arxiv.org/pdf/2105.06339.pdf))
- **manifolds** (an application of geometric deep learning on manifolds is shown in [this short video](https://www.youtube.com/watch?v=-b0e41H4J_A))

Even **images** and **text** can be represented using a graph!  
An image can be seen as a graph in which each node represents a pixel and each edge connects that pixel to one of its adjacent pixels.  
A node representing a non-border pixel (like the central one in the image below) always has 8 neighbours, and the vector stored in each node has 3 dimensions, one for each RGB channel.
<div>
    <img src="/images/image_to_graph.png" alt="Drawing" width=400/> 
</div>
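As a rough sketch of this idea, the snippet below turns an image grid into an adjacency list with 8-neighbour connectivity and keeps the RGB triplet of each pixel as its node features. The `image_to_graph` helper and the tiny 3x3 image are illustrative assumptions:

```python
def image_to_graph(height, width):
    """Return an adjacency list for an image grid with 8-neighbour connectivity.

    Pixel (r, c) is mapped to node id r * width + c.
    """
    neighbours = {}
    for r in range(height):
        for c in range(width):
            node = r * width + c
            neighbours[node] = []
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue  # a pixel is not its own neighbour
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < height and 0 <= cc < width:
                        neighbours[node].append(rr * width + cc)
    return neighbours


# Hypothetical 3x3 RGB image: each pixel stores an (R, G, B) vector
image = [[(r * 3 + c, 0, 255) for c in range(3)] for r in range(3)]
node_features = [pixel for row in image for pixel in row]

graph = image_to_graph(3, 3)
print(len(graph[4]))  # central pixel: 8 neighbours
print(len(graph[0]))  # corner pixel: 3 neighbours
```

Border pixels get fewer neighbours: a corner pixel has 3 and a pixel on an edge of the image has 5.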

When dealing with text, [tokenization](https://nlp.stanford.edu/IR-book/html/htmledition/tokenization-1.html) is one of the fundamental pre-processing steps: it splits the text into characters, words or similar units that become separate elements of a sequence. Each of these elements can then be modelled as a node of a graph.
<div>
    <img src="/images/text_to_graph.png" alt="Drawing" width=400/> 
</div>
<br>
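One simple option is to link each token to the one that follows it with a directed edge, which turns a sentence into a path graph. The sketch below assumes a naive whitespace tokenizer and a hypothetical `text_to_graph` helper; a real pipeline would use a proper tokenizer:

```python
def text_to_graph(text):
    """Turn a sentence into a directed path graph over its tokens."""
    tokens = text.split()  # naive whitespace tokenization (assumption)
    # each token (node i) points to the token that follows it (node i + 1)
    edges = [(i, i + 1) for i in range(len(tokens) - 1)]
    return tokens, edges


tokens, edges = text_to_graph("graphs are all around us")
print(tokens)
print(edges)
```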

<h2><strong>Graph Machine Learning: Tasks and Challenges</strong></h2> <a class="anchor" id="GML-tasks"></a>

The predictions we can make on a graph happen at the same three levels described above:  

- **Graph level** predictions (where we try to predict a property of the entire graph, and the label is assigned to the entire graph itself)
- **Node level** predictions (where we want to predict the role or identity of a node; a typical example is [Zachary's Karate Club](https://en.wikipedia.org/wiki/Zachary%27s_karate_club))
- **Edge level** predictions (where we make predictions on the relationships between the graph's nodes)  

The first step when facing one of the above problems is to find a suitable representation of the data, one that the model we are using can understand.  
In my experience, a common approach was simply to use graph modelling as a "feature extractor" for traditional machine learning models, adding predictors derived from network topology or interaction simulations.  
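As an illustration of this "feature extractor" approach, the sketch below computes two classic topological predictors for each node, its degree and its local clustering coefficient, which could then be added as extra columns to a tabular dataset. The choice of these two features, the `topological_features` helper and the toy graph are assumptions made for the example:

```python
from itertools import combinations


def topological_features(adjacency):
    """Compute [degree, local clustering coefficient] for every node.

    `adjacency` maps each node to the set of its neighbours (undirected graph).
    """
    features = {}
    for node, neighbours in adjacency.items():
        degree = len(neighbours)
        # number of links among the neighbours (triangles through `node`)
        links = sum(1 for u, v in combinations(neighbours, 2) if v in adjacency[u])
        possible = degree * (degree - 1) / 2
        clustering = links / possible if possible else 0.0
        features[node] = [degree, clustering]
    return features


# Toy graph: a triangle 0-1-2 plus an extra node 3 attached to node 2
adjacency = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
features = topological_features(adjacency)
for node, row in features.items():
    print(node, row)
```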
  
However, to fully leverage the potential of graph machine learning, graphs must be represented in a way that is compatible with neural networks, so that the graph structure itself becomes a source of information for the model. Graph Neural Networks are the architecture designed to do exactly that.  
Machine learning models are typically fed fixed-size arrays as input, so it is not trivial to represent a graph (with all the information levels mentioned above and their interactions) in a form suitable for training.  
The adjacency matrix could be used, but it presents a few problems (very well explained in [this video](https://youtu.be/JtDgmmQ60x8?t=411) from the [Pytorch Geometric Tutorial project](https://antoniolonga.github.io/Pytorch_geometric_tutorials/index.html)[<sup>2</sup>](#fn2)):  
- **sparsity** and the resulting space inefficiency
- it depends on node ordering and **it is not permutation invariant**: many different adjacency matrices can encode the same connectivity, but a standard neural network would produce different outputs for each of them
- it does not handle **changes in graph size**: if a new node is added to the graph, a model trained on fixed-size matrices can no longer process the new shape  

This is why Graph Neural Networks rely on other mechanisms, such as *message passing* and *aggregation*.
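The permutation problem can be seen in a few lines of Python: relabelling the nodes of the same graph produces a different adjacency matrix, so a model that reads the flattened matrix would receive a different input, while an order-independent quantity such as the sum of all entries does not change. The `permute` and `flatten` helpers are illustrative names:

```python
def permute(A, perm):
    """Relabel nodes: new node k is old node perm[k] (i.e. P A P^T)."""
    n = len(A)
    return [[A[perm[i]][perm[j]] for j in range(n)] for i in range(n)]


def flatten(A):
    """Flatten a matrix row by row, as a plain MLP input would be."""
    return [x for row in A for x in row]


# Path graph 0 - 1 - 2
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
perm = [1, 0, 2]  # swap the labels of nodes 0 and 1
A_perm = permute(A, perm)

# Same connectivity, different matrix: a flattened input changes
print(flatten(A) == flatten(A_perm))            # False
# An order-independent readout (sum of entries = 2 * number of edges) does not
print(sum(flatten(A)) == sum(flatten(A_perm)))  # True
```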

<h2><strong>Computation Graphs, Message Passing and Aggregation</strong></h2> <a class="anchor" id="computation-graphs"></a>  

The main intuition behind Graph Neural Networks is that they learn structural information about graphs, based on the assumption that neighboring nodes (nodes connected by an edge) share similar properties. Therefore, **the computation graph of a node is defined by its neighbors, and every node has its own computation graph**, as shown in the example below.

<div align="center">
    <img src="/images/computation_graph.png" alt="Drawing" width=600/> 
</div>

The image shows what a Graph Neural Network with 3 layers does to compute the representation of a single node.  
Each layer corresponds to a different distance from the target node:
- Layer 0 contains the input features of the most distant neighbors (in this case, neighbors 2 hops away)
- Layer 1 is a hidden layer that passes the aggregated information to node F (1 hop away)
- Layer 2 is, in this case, the final layer, which returns the representation of node A
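The unrolling shown in the figure can be sketched recursively: each node in the tree receives messages from its neighbours one layer below. The adjacency list here is a hypothetical graph, not the one in the image, and `computation_graph` is an illustrative helper. Note that in an undirected graph the target node reappears two layers down, because it is a neighbour of its own neighbours:

```python
def computation_graph(adjacency, target, num_layers):
    """Unroll the computation graph of `target` for a GNN with `num_layers` layers.

    Returns a nested (node, children) tuple; the children of a node are its
    neighbours, whose messages it aggregates one layer later.
    """
    if num_layers == 0:
        return (target, [])  # input layer: raw node features
    return (target, [computation_graph(adjacency, nbr, num_layers - 1)
                     for nbr in adjacency[target]])


def show(tree, depth=0):
    """Print the tree with one indentation level per layer."""
    node, children = tree
    print("  " * depth + node)
    for child in children:
        show(child, depth + 1)


adjacency = {"A": ["B", "C"], "B": ["A", "D"], "C": ["A"], "D": ["B"]}
tree = computation_graph(adjacency, "A", 2)
show(tree)
```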

From the image, we can see that both Layer 0 and Layer 1 include a neural network (usually a simple [Multi-layer perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron) or a [Recurrent Neural Network](https://en.wikipedia.org/wiki/Recurrent_neural_network)) that computes the function $F(x_j) = \mathbf{W} \cdot x_j + b$ on the features $x_j$ of each neighbor; the same weights $\mathbf{W}$ and bias $b$ are shared by all the nodes within a layer.  
In the article [*Graph Neural Networks for Novice Math Fanatics*](https://rish16.notion.site/Graph-Neural-Networks-for-Novice-Math-Fanatics-c51b922a595b4efd8647788475461d57)[<sup>3</sup>](#fn3), **message passing for a Graph Neural Network layer** is defined as: 
>**the process of taking node features of the neighbours, transforming them, and "passing" them to the source node. This process is repeated, in parallel, for all nodes in the graph. In that way, all neighbourhoods are examined by the end of this step.**
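Here is a minimal sketch of the transformation part of message passing, assuming each message is produced by the single linear map $F(x_j) = \mathbf{W} \cdot x_j + b$ (no activation) with the same weights for every neighbour. The graph, the weights and the helper names are made up for illustration:

```python
def linear(W, b, x):
    """F(x) = W x + b, with W given as a list of rows and b as a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def compute_messages(adjacency, features, W, b):
    """For every node, transform each neighbour's features with the shared
    weights and collect the resulting messages (done for all nodes)."""
    return {node: [linear(W, b, features[nbr]) for nbr in nbrs]
            for node, nbrs in adjacency.items()}


adjacency = {0: [1, 2], 1: [0], 2: [0]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
W = [[1.0, 2.0], [0.0, 1.0]]  # hypothetical 2 -> 2 weight matrix
b = [0.0, 0.5]                # hypothetical bias

messages = compute_messages(adjacency, features, W, b)
print(messages[0])  # messages sent to node 0 by nodes 1 and 2
```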

Every layer includes an **aggregation step** that applies a permutation-invariant operation (one that does not depend on the order of the nodes), such as *sum*, *mean*, *min* or *max*, to the incoming messages before passing the result on.  
At the target node, the aggregated messages are combined with the node's own features through a simple addition or a concatenation, so that the node's embedding contains information coming from both its neighbors and the node itself.  
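The sketch below shows a few permutation-invariant aggregations over a list of messages and the two ways of combining the result with the node's own features. The message values are arbitrary, and `aggregate` and `combine` are hypothetical helper names:

```python
def aggregate(messages, how="sum"):
    """Element-wise, permutation-invariant aggregation of a list of vectors."""
    if not messages:
        raise ValueError("node has no incoming messages")
    columns = list(zip(*messages))  # one tuple per feature dimension
    if how == "sum":
        return [sum(c) for c in columns]
    if how == "mean":
        return [sum(c) / len(c) for c in columns]
    if how == "max":
        return [max(c) for c in columns]
    if how == "min":
        return [min(c) for c in columns]
    raise ValueError(f"unknown aggregation: {how}")


def combine(self_features, aggregated, how="add"):
    """Merge the node's own features with the aggregated messages."""
    if how == "add":
        return [a + b for a, b in zip(self_features, aggregated)]
    if how == "concat":
        return self_features + aggregated
    raise ValueError(f"unknown combination: {how}")


msgs = [[2.0, 1.5], [3.0, 1.5]]
# Reordering the messages does not change the result
assert aggregate(msgs, "max") == aggregate(msgs[::-1], "max")
print(aggregate(msgs, "sum"))                                 # [5.0, 3.0]
print(aggregate(msgs, "mean"))                                # [2.5, 1.5]
print(combine([1.0, 0.0], aggregate(msgs, "sum")))            # [6.0, 3.0]
print(combine([1.0, 0.0], aggregate(msgs, "sum"), "concat"))  # [1.0, 0.0, 5.0, 3.0]
```

Addition requires the node features and the aggregated messages to have the same size, while concatenation produces a longer vector that the next layer must expect.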
For example, using addition, the whole process that produces the new representation $h_i$ of node $i$ can be summarised as:  
  
  
$$\large h_i = \sigma(K(H(x_i) + \bar{m}_i))$$  

where:
- $x_i$ represents the features of node $i$, which are combined with the aggregated messages
- $\bar{m}_i$ is the aggregation of the messages coming from the neighbors of $i$
- $H$ is the simple neural network mentioned before, applied to the node features
- $K$ is another neural network that projects the sum of the transformed node features and the aggregated messages into another dimension
- $\sigma$ is an activation function like [*ReLU*](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) or [*Leaky ReLU*](https://paperswithcode.com/method/leaky-relu#:~:text=Leaky%20Rectified%20Linear%20Unit%2C%20or,instead%20of%20a%20flat%20slope.&text=learnt%20during%20training.-,This%20type%20of%20activation%20function%20is%20popular%20in%20tasks%20where,example%20training%20generative%20adversarial%20networks.)
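A minimal sketch of this update rule, assuming that $H$, $K$ and the message function are single linear layers without bias, that $\bar{m}_i$ is the mean of the neighbour messages and that $\sigma$ is ReLU. All weights and the toy graph are illustrative values:

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(W, x):
    """Linear layer without bias: W x, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def update_node(i, adjacency, features, W_msg, W_H, W_K):
    """h_i = relu(K(H(x_i) + mean_j F(x_j))) with linear F, H and K."""
    messages = [linear(W_msg, features[j]) for j in adjacency[i]]
    if messages:
        m_bar = [sum(c) / len(messages) for c in zip(*messages)]
    else:
        m_bar = [0.0] * len(W_msg)  # isolated node: no incoming messages
    combined = [a + b for a, b in zip(linear(W_H, features[i]), m_bar)]
    return relu(linear(W_K, combined))


adjacency = {0: [1, 2], 1: [0], 2: [0]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
identity = [[1.0, 0.0], [0.0, 1.0]]
W_K = [[1.0, -1.0]]  # projects the 2-d combination to 1 dimension

h = {i: update_node(i, adjacency, features, identity, identity, W_K)
     for i in adjacency}
print(h)
```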
  
To conclude, a single Graph Neural Network layer applied to node $i$ and using addition can be formulated as:  

$$\large h_i^{(l+1)} = \sigma\left(\mathbf{W}_1 \cdot h_i^{(l)} + \sum_{j \in \mathcal{N}_i}\mathbf{W}_2 \cdot h_j^{(l)}\right)$$
  
where $h_i^{(l)}$ is the representation of node $i$ at layer $l$ and $\mathcal{N}_i$ is the set of neighbors of $i$. The layer applies an activation to the sum of two terms: the features of $i$ multiplied by a weight matrix $\mathbf{W}_1$, and the features of each neighbor $j$ multiplied by a second weight matrix $\mathbf{W}_2$. Both matrices are learned and shared by all the nodes in the layer.  
In practice, during training the layer is computed for all nodes at once, using the adjacency matrix to select the neighbors of each node.
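Stacking the node representations as the rows of a matrix $\mathbf{Z}$, the same layer can be written for the whole graph using the adjacency matrix $\mathbf{A}$ (this matrix form is my own restatement for row vectors, where row $i$ of $\mathbf{Z}'$ is the new representation of node $i$):

$$\mathbf{Z}' = \sigma\left(\mathbf{Z}\mathbf{W}_1^\top + \mathbf{A}\mathbf{Z}\mathbf{W}_2^\top\right)$$

The sketch below implements both the matrix version and the per-node version in plain Python, with ReLU as $\sigma$, and checks that they agree on a small hypothetical graph with made-up weights:

```python
def matmul(X, Y):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def transpose(X):
    return [list(col) for col in zip(*X)]


def gnn_layer(A, Z, W1, W2):
    """One layer for all nodes at once: relu(Z W1^T + A Z W2^T).

    A: n x n adjacency matrix, Z: n x d node features (row i = node i),
    W1, W2: d_out x d weight matrices shared by every node.
    """
    self_term = matmul(Z, transpose(W1))
    neigh_term = matmul(matmul(A, Z), transpose(W2))  # A Z sums neighbour rows
    return [[max(0.0, s + m) for s, m in zip(rs, rm)]
            for rs, rm in zip(self_term, neigh_term)]


def gnn_layer_node(i, A, Z, W1, W2):
    """Same layer for node i only, following the per-node formula."""
    own = [sum(w * z for w, z in zip(row, Z[i])) for row in W1]
    neigh = [0.0] * len(W2)
    for j, connected in enumerate(A[i]):
        if connected:
            msg = [sum(w * z for w, z in zip(row, Z[j])) for row in W2]
            neigh = [a + b for a, b in zip(neigh, msg)]
    return [max(0.0, a + b) for a, b in zip(own, neigh)]


A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]      # path graph 0 - 1 - 2
Z = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[0.5, 0.0], [0.0, 1.0]]

out = gnn_layer(A, Z, W1, W2)
per_node = [gnn_layer_node(i, A, Z, W1, W2) for i in range(len(A))]
print(out)
print(all(abs(a - b) < 1e-9 for r1, r2 in zip(out, per_node) for a, b in zip(r1, r2)))
```

The dense matrix product is only for clarity; real graphs are sparse, so libraries usually store the edges as a list of index pairs instead of a full $n \times n$ matrix.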

<h2><strong>Transductive and inductive setting</strong></h2> <a class="anchor" id="transductive-inductive"></a>  

<h2><strong>References</strong></h2> <a class="anchor" id="references"></a>

<p id="fn1">[1] Sanchez-Lengeling B., et al., <a href="https://distill.pub/2021/gnn-intro/">A Gentle Introduction to Graph Neural Networks</a>, Distill, 2021.</p>
<p id="fn2">[2] Longa A., Santin G., Pellegrini G., <a href="https://antoniolonga.github.io/Pytorch_geometric_tutorials/index.html">Pytorch Geometric Tutorial</a>, 2021.</p>
<p id="fn3">[3] Anand R., <a href="https://rish16.notion.site/Graph-Neural-Networks-for-Novice-Math-Fanatics-c51b922a595b4efd8647788475461d57">Graph Neural Networks for Novice Math Fanatics</a>, 2021.</p>