# Creating a Message Passing Layer
* There are 3 critical functions needed to define a PyG Message Passing Layer: `forward`, `message`, and `aggregate`.
* The job of a message passing layer is to update the current feature representation or embedding of each node in a graph by propagating and transforming information within the graph. 
* Overall, the general paradigm of a message passing layer is:
    1. pre-processing
    2. message passing / propagation
    3. post-processing

## Forward Function
The `forward` function handles the pre- and post-processing of node features / embeddings, and it initiates message passing by calling the `propagate` function.
* We can place the logic for updating node embeddings after message passing inside the `forward` function. More specifically, after information is propagated (message passing), we can further transform the node embeddings output by `propagate`. Therefore, the output of `forward` is exactly the node embeddings after one GNN layer.



## Propagate Function
* The `propagate` function encapsulates the message passing process.
* It does so by calling three important functions:
    1. `message`
    2. `aggregate`
    3. `update`
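Conceptually, `propagate` gathers features along edges, builds one message per edge, aggregates the messages per central node, and then applies `update`. The plain-PyTorch sketch below mimics that flow for a layer whose message is `x_j` and whose aggregation is a sum. It is a simplified illustration, not PyG's actual implementation (which also handles sparse inputs, bipartite sizes, and flow direction); `propagate_sketch` is a hypothetical helper name, and it assumes PyG's default convention that row 0 of `edge_index` holds the neighbor (source) and row 1 the central (target) node.

```python
import torch


def propagate_sketch(x, edge_index):
    """Hypothetical helper: message -> aggregate (sum) -> update (identity)."""
    src, dst = edge_index            # row 0: neighbor j, row 1: central i
    # message: one row per edge, here just the neighbor's features x_j
    messages = x[src]                # [E, d]
    # aggregate: sum all messages that target the same central node i
    aggregated = torch.zeros_like(x).index_add_(0, dst, messages)  # [N, d]
    # update: identity, like MessagePassing's default update()
    return aggregated


x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 1, 2], [2, 2, 0]])  # edges 0->2, 1->2, 2->0
out = propagate_sketch(x, edge_index)
# out -> [[3.], [0.], [3.]]
```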

```python
self.propagate(edge_index, size=size, x=(x_j, x_i), extra=(extra_j, extra_i))
```
Calling `propagate` (typically inside `forward`) initiates the message passing process.

  - `edge_index` is passed to the forward function and captures the edge structure of the graph. As a LongTensor of shape $[2, |\mathcal{E}|]$, each column is one edge; with PyG's default `flow="source_to_target"`, row 0 holds the neighboring (source) node and row 1 the central (target) node that receives the message.
  - `x=(x_j, x_i)` represents the **node features** that will be used in message passing. To explain why we pass a tuple, we first look at how our edges are represented. For every edge $(i, j) \in \mathcal{E}$, we differentiate $i$ as the central (target) node ($x_{central}$) and $j$ as the neighboring (source) node ($x_{neighbor}$):
    - $i$ indicates a central node
    - $j$ indicates a neighboring node

    PyG reads such a tuple in (source, target) order: the first element is indexed to produce `x_j` and the second to produce `x_i`. When both roles use the same feature matrix, `x=x` and `x=(x, x)` are equivalent.

    In message passing, for a central node $u$ we aggregate and transform all of the messages associated with the nodes $v$ s.t. $(u, v) \in \mathcal{E}$ (i.e. $v \in \mathcal{N}_{u}$). Thus, the subscripts `_i` and `_j` allow us to specifically differentiate features associated with central nodes (i.e. nodes receiving message information) and neighboring nodes (i.e. nodes passing messages).


  - `extra=(extra_j, extra_i)` represents additional information that we can associate with each node beyond its current feature embedding. In fact, we can include as many additional parameters of the form `param=(param_j, param_i)` as we would like. Again, indexing with `_i` and `_j` allows us to differentiate central and neighboring nodes.

  - `size` is the size $(N, M)$ of the assignment matrix in case `edge_index` is a LongTensor (it is ignored for a `SparseTensor`). If set to `None`, the size is inferred automatically and assumed to be square.


  The output of the `propagate` function is a matrix of node embeddings after the message passing process and has shape $[N, d]$.

### Message Function
```python
def message(self, x_j, ...):
```
The `message` function is called by `propagate` and constructs the messages from neighboring nodes $j$ to central nodes $i$ for each edge $(i, j)$ in `edge_index`. This function can take any argument that was initially passed to `propagate`. Furthermore, we can again differentiate central nodes and neighboring nodes by appending `_i` or `_j` to the variable name, e.g. `x_i` and `x_j`. Looking more specifically at the variables, we have:

  - `x_j` represents a matrix of feature embeddings for all neighboring nodes passing their messages along their respective edge (i.e. all nodes $j$ for edges $(i, j) \in \mathcal{E}$). Thus, its shape is $[|\mathcal{E}|, d]$!


  Critically, we see that the output of the `message` function is a matrix of messages (one row per edge) ready to be aggregated, with shape $[|\mathcal{E}|, d]$.

### Aggregate Function
```
def aggregate(self, inputs, index, dim_size = None):
```
Lastly, the `aggregate` function is used to aggregate the messages from neighboring nodes. Looking at the parameters we highlight:

  - `inputs` represents a matrix of the messages passed from neighboring nodes (i.e. the output of the `message` function).
  - `index` is a 1-D LongTensor with one entry per row of `inputs` (shape $[|\mathcal{E}|]$) that stores the central node corresponding to each row / message in the `inputs` matrix. Thus, `index` tells us which rows / messages to aggregate for each central node.

  The output of `aggregate` is of shape $[N, d]$.
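The role of `index` can be made concrete with a plain-PyTorch mean aggregation, a sketch of what `torch_scatter.scatter(..., reduce="mean")` computes along dimension 0. `mean_aggregate` is a hypothetical helper; in this sketch, central nodes that receive no messages get zeros.

```python
import torch


def mean_aggregate(inputs, index, dim_size):
    """Hypothetical helper: average the rows of `inputs` that share an `index` value."""
    # inputs: [E, d] messages; index: [E] central node of each message
    summed = torch.zeros(dim_size, inputs.size(1), dtype=inputs.dtype)
    summed.index_add_(0, index, inputs)
    counts = torch.zeros(dim_size, dtype=inputs.dtype)
    counts.index_add_(0, index, torch.ones_like(index, dtype=inputs.dtype))
    # clamp avoids dividing by zero for nodes without incoming messages
    return summed / counts.clamp(min=1).unsqueeze(-1)


inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = torch.tensor([0, 0, 2])  # messages 0 and 1 go to node 0, message 2 to node 2
out = mean_aggregate(inputs, index, dim_size=3)
# out -> [[2., 3.], [0., 0.], [5., 6.]]
```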


For additional resources refer to the PyG documentation for implementing custom message passing layers: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html

# message_and_aggregate
* Fuses the `message()` and `aggregate()` functions into a single computation step. It gets called (instead of `message()` and `aggregate()`) whenever it is implemented and `edge_index` is passed as a `SparseTensor`.
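For sum aggregation, the fusion works because summing neighbor features is a sparse matrix product: with `adj_t[i, j] = 1` for every edge from neighbor `j` to central node `i`, `adj_t @ x` yields the aggregated messages without materializing one message per edge. The sketch below checks this in plain PyTorch, using `torch.sparse.mm` as a stand-in for `torch_sparse.matmul`; the toy graph is made up.

```python
import torch

x = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
edge_index = torch.tensor([[0, 1, 2], [2, 2, 0]])  # row 0: neighbor j, row 1: central i
N = x.size(0)

# Two steps: build one message per edge (x_j), then sum them into each central node i
messages = x[edge_index[0]]                                      # [E, d]
two_step = torch.zeros_like(x).index_add_(0, edge_index[1], messages)

# Fused: transposed adjacency with adj_t[i, j] = 1, then one sparse-dense product
adj_t = torch.sparse_coo_tensor(edge_index.flip(0), torch.ones(edge_index.size(1)), (N, N))
fused = torch.sparse.mm(adj_t, x)                                # [N, d]
# both -> [[3., 3.], [0., 0.], [1., 2.]]
```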

In [None]:
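import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_sparse import matmul


class GINConv(MessagePassing):
    def __init__(self, in_channels, out_channels, eps=0.0):
        super().__init__(aggr="add")
        self.eps = eps
        # MLP applied after aggregation, as in the GIN update rule
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, out_channels), nn.ReLU(),
            nn.Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # edge_index: LongTensor [2, E] or a transposed SparseTensor adj_t
        out = self.propagate(edge_index, x=x)
        return self.mlp((1 + self.eps) * x + out)

    def message(self, x_j):
        # used when edge_index is a LongTensor
        return x_j

    def message_and_aggregate(self, adj_t, x):
        # used instead of message() + aggregate() when edge_index is a SparseTensor;
        # PyG passes the adjacency positionally, so naming it edge_index also works
        return matmul(adj_t, x, reduce=self.aggr)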

## Stacking GNN Layers

In [None]:
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

class GNNStack(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):
        super(GNNStack, self).__init__()
        assert (args.num_layers >= 1), 'Number of layers is not >=1'
        conv_model = self.build_conv_model(args.model_type, args.heads)
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim))
        # every conv layer outputs args.heads * hidden_dim features
        for l in range(args.num_layers-1):
            self.convs.append(conv_model(args.heads * hidden_dim, hidden_dim))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), 
            nn.Linear(hidden_dim, output_dim))

        self.dropout = args.dropout
        self.num_layers = args.num_layers

        self.emb = emb

    def build_conv_model(self, model_type, heads=1):
        if model_type == 'GraphSage':
            # GraphSage outputs hidden_dim features (a single "head"), so the
            # layer dimensions above only line up when args.heads == 1.
            assert heads == 1, 'GraphSage requires args.heads == 1'
            return GraphSage
        elif model_type == 'GAT':
            # Multi-head GAT concatenates its heads, so each layer outputs
            # heads * hidden_dim features. That is why the later conv layers and
            # the first nn.Linear in post_mp take args.heads * hidden_dim inputs.
            # Pass heads explicitly; otherwise GAT falls back to its default (2).
            return lambda in_dim, out_dim: GAT(in_dim, out_dim, heads=heads)
        raise ValueError('Unknown model_type: {}'.format(model_type))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        if self.emb:
            return x

        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)

## Scatter
Reduces all values from the `src` tensor into `out` at the indices specified in the `index` tensor along a given axis `dim`.

* Application: the aggregation step inside the `aggregate` method of a `MessagePassing` subclass (see the GraphSage and GAT implementations below).

For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.

* `src` – The source tensor.
* `index` – The indices of elements to scatter.
* `dim` – The axis along which to index. (default: -1)
* `out` – The destination tensor.
* `dim_size` – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.
* `reduce` – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")



In [None]:
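import torch
import torch_scatter

# 4 messages (one row per edge) and the central node each one is sent to
inputs = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
index = torch.tensor([0, 0, 1, 2])
# critical to set dim: the axis that indexes edges/nodes (dim=0 here)
out = torch_scatter.scatter(inputs, index, dim=0, dim_size=3, reduce="mean")
# out -> [[2., 3.], [5., 6.], [7., 8.]]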
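Recent PyTorch versions (1.12 and later) ship a built-in `Tensor.scatter_reduce` that can stand in for `torch_scatter.scatter` if the extension is unavailable. One difference worth showing: the built-in version expects `index` to have the same shape as `src`, so the per-row node index must first be broadcast across the feature dimension. The values below are made up; `"amax"` is PyTorch's name for the max reduction.

```python
import torch

src = torch.tensor([[1.0, 8.0], [3.0, 4.0], [5.0, 6.0]])  # 3 messages, 2 features
index = torch.tensor([0, 0, 1])                            # central node of each message

# Broadcast the per-row index across features so it matches src's shape
idx = index.unsqueeze(-1).expand_as(src)
# include_self=False: ignore the zeros in the initial output when taking the max
out = torch.zeros(2, 2).scatter_reduce(0, idx, src, reduce="amax", include_self=False)
# out -> [[3., 8.], [5., 6.]]
```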

# Models

In [None]:
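import copy
import numpy as np
import pandas as pd
import torch
from tqdm import trange
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG
# build_optimizer(args, params) -> (scheduler, optimizer) is assumed to be defined in an earlier cell

def train(dataset, args):
    
    print("Node task. test set size:", np.sum(dataset[0]['test_mask'].numpy()))
    print()
    test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)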

    # build model
    model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, 
                            args)
    scheduler, opt = build_optimizer(args, model.parameters())

    # train
    losses = []
    test_accs = []
    best_acc = 0
    best_model = None
    for epoch in trange(args.epochs, desc="Training", unit="Epochs"):
        total_loss = 0
        model.train()
        for batch in loader:
            opt.zero_grad()
            pred = model(batch)
            label = batch.y
            pred = pred[batch.train_mask]
            label = label[batch.train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs  # we care about #graphs in batch (not #nodes)
        total_loss /= len(loader.dataset)
        losses.append(total_loss)

        if epoch % 10 == 0:
            test_acc = test(test_loader, model)
            test_accs.append(test_acc)
            if test_acc > best_acc:
                best_acc = test_acc
                best_model = copy.deepcopy(model)
        else:
            test_accs.append(test_accs[-1])
    
    return test_accs, losses, best_model, best_acc, test_loader

def test(loader, test_model, is_validation=False, save_model_preds=False, model_type=None):
    test_model.eval()

    correct = 0
    # Note that Cora is only one graph!
    for data in loader:
        with torch.no_grad():
            # max(dim=1) returns values, indices tuple; only need indices
            pred = test_model(data).max(dim=1)[1]
            label = data.y

        mask = data.val_mask if is_validation else data.test_mask
        # node classification: only evaluate on nodes in test set
        pred = pred[mask]
        label = label[mask]

        if save_model_preds:
            print("Saving Model Predictions for Model Type", model_type)

            # use a new name so the loop variable `data` is not overwritten
            preds = {}
            preds['pred'] = pred.view(-1).cpu().detach().numpy()
            preds['label'] = label.view(-1).cpu().detach().numpy()

            df = pd.DataFrame(data=preds)
            # Save locally as csv
            df.to_csv('CORA-Node-' + model_type + '.csv', sep=',', index=False)
            
        correct += pred.eq(label).sum().item()

    total = 0
    for data in loader.dataset:
        total += torch.sum(data.val_mask if is_validation else data.test_mask).item()

    return correct / total
  
class objectview(object):
    def __init__(self, d):
        self.__dict__ = d
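`objectview` lets a plain dictionary of hyperparameters be read with attribute syntax (`args.hidden_dim`), which is how `GNNStack` and `train` consume `args`. A hypothetical configuration is sketched below: the keys are the ones read by the code above, the values are arbitrary placeholders rather than tuned settings, and `build_optimizer` may read additional keys not listed here.

```python
class objectview(object):
    def __init__(self, d):
        self.__dict__ = d  # dictionary keys become attributes


# Hypothetical settings; keys match those read by GNNStack and train()
args = objectview({
    'model_type': 'GAT',  # or 'GraphSage' (which needs heads=1)
    'num_layers': 2,
    'heads': 2,
    'hidden_dim': 32,
    'dropout': 0.5,
    'batch_size': 32,
    'epochs': 100,
})
```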


## GraphSage Implementation

For our first GNN layer, we will implement the well-known GraphSage ([Hamilton et al. (2017)](https://arxiv.org/abs/1706.02216)) layer! 

For a given *central* node $v$ with current embedding $h_v^{(l-1)}$, the message passing update rule to transform $h_v^{(l-1)} \rightarrow h_v^{(l)}$ is as follows: 

\begin{equation}
h_v^{(l)} = W_l\cdot h_v^{(l-1)} + W_r \cdot AGG(\{h_u^{(l-1)}, \forall u \in N(v) \})
\end{equation}

where $W_l$ and $W_r$ are learnable weight matrices and the nodes $u$ are *neighboring* nodes. Additionally, we use mean aggregation for simplicity:

\begin{equation}
AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) = \frac{1}{|N(v)|} \sum_{u\in N(v)} h_u^{(l-1)}
\end{equation}

One thing to note is that we're adding a **skip connection** to our GraphSage implementation through the term $W_l\cdot h_v^{(l-1)}$. 

Lastly, $\ell_2$ normalization of the node embeddings is applied to the output of each layer (when `normalize=True`).
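A single GraphSage update can be worked through in plain PyTorch to connect the equations to tensor operations: mean-aggregate neighbor embeddings, add the skip term $W_l h_v^{(l-1)}$, then $\ell_2$-normalize each row. The random weights and toy graph are illustrative assumptions; in this sketch a node without neighbors (node 2) simply gets a zero aggregate.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d_in, d_out = 3, 4, 2
h = torch.randn(N, d_in)                            # h^{(l-1)} for every node
edge_index = torch.tensor([[1, 2, 0], [0, 0, 1]])   # row 0: neighbor u, row 1: central v
W_l = torch.randn(d_out, d_in)                      # skip-connection weights
W_r = torch.randn(d_out, d_in)                      # neighbor weights

src, dst = edge_index
# AGG: mean of neighbor embeddings per central node
deg = torch.zeros(N).index_add_(0, dst, torch.ones(dst.numel()))
agg = torch.zeros(N, d_in).index_add_(0, dst, h[src]) / deg.clamp(min=1).unsqueeze(-1)

h_new = h @ W_l.T + agg @ W_r.T                     # W_l h_v + W_r AGG(...)
h_new = F.normalize(h_new, p=2.0, dim=1)            # l2-normalize each embedding
```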

In [None]:
from torch_geometric.nn.conv import MessagePassing

class GraphSage(MessagePassing):
    
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = True, **kwargs):
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        # W_l (skip connection) and W_r (aggregated neighbors); `bias` toggles both offsets
        self.lin_l = nn.Linear(self.in_channels, self.out_channels, bias=bias)
        self.lin_r = nn.Linear(self.in_channels, self.out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        """ GraphSAGE forward pass
        
        Parameters
        ----------
        x : torch.FloatTensor
            nodes hidden representation of shape (N, hidden_dim)

        edge_index : torch.LongTensor/torch_sparse.SparseTensor
            defines the underlying graph connectivity/message passing flow


        size (optional) : tuple
            The size (N, M) of the assignment matrix in case edge_index is a LongTensor.
            This argument is ignored in case edge_index is a torch_sparse.SparseTensor

        Returns
        -------
        torch.FloatTensor
            node embeddings of nodes in the batch
        """

        out = self.propagate(edge_index, x=(x, x), size=size)  # aggregate neighbor messages with propagate
        out = self.lin_r(out)
        out += self.lin_l(x)  # skip connection
        if self.normalize:
            out = torch.nn.functional.normalize(out, p=2.0, dim=1)

        return out

    def message(self, x_j):
        """ GraphSAGE message

        Parameters
        ----------
        x_j : torch.FloatTensor
            x_i's neighbors' hidden representation (E, hidden_dim)
        
        Returns
        -------
        torch.FloatTensor
            message of node x_i before aggregation

        """

        out = x_j  # the message is the neighbor's (previous) state

        return out

    def aggregate(self, inputs, index, dim_size = None):
        """ Aggregate messages from all neighbors based on index

        Parameters
        ----------
        inputs : torch.FloatTensor
            Inputs to aggregate, i.e. message representations from neighbors

        index : torch.LongTensor
             The indices of elements to scatter (aggregate by index)

        dim_size (optional): int
            Output with size dim_size at dimension dim

        Returns
        -------
        torch.FloatTensor
            Final aggregated representation of a message
        """

        out = None

        # The axis along which to index number of nodes.
        node_dim = self.node_dim

        out = torch_scatter.scatter(inputs, index, dim=node_dim, dim_size=dim_size, reduce="mean")

        return out


## GAT Implementation

Attention mechanisms have become the state-of-the-art in many sequence-based tasks such as machine translation and learning sentence representations. One of the major benefits of attention-based mechanisms is their ability to focus on the most relevant parts of the input to make decisions. In this problem, we will see how attention mechanisms can be used to perform node classification over graph-structured data through the usage of Graph Attention Networks (GATs) ([Veličković et al. (2018)](https://arxiv.org/abs/1710.10903)).

The building block of the Graph Attention Network is the graph attention layer, which is a variant of the aggregation function. Let $N$ be the number of nodes and $F$ be the dimension of the feature vector for each node. The input to each graph attentional layer is a set of node features: $\mathbf{h} = \{\overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N}\}$, $\overrightarrow{h_i} \in \mathbb{R}^F$. The output of each graph attentional layer is a new set of node features, which may have a new dimension $F'$: $\mathbf{h'} = \{\overrightarrow{h_1'}, \overrightarrow{h_2'}, \dots, \overrightarrow{h_N'}\}$, with $\overrightarrow{h_i'} \in \mathbb{R}^{F'}$.

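We will now describe how this transformation is performed for each graph attention layer. First, a shared linear transformation parametrized by the weight matrix $\mathbf{W} \in \mathbb{R}^{F' \times F}$ is applied to every node. In the equations below, $\mathbf{W_l}$ and $\mathbf{W_r}$ denote this transformation applied to the central node $i$ and to the neighboring node $j$, respectively; since it is shared, $\mathbf{W_l} = \mathbf{W_r} = \mathbf{W}$.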

Next, we perform self-attention on the nodes. We use a shared attention function $a$:
\begin{equation} 
a : \mathbb{R}^{F'} \times \mathbb{R}^{F'} \rightarrow \mathbb{R}.
\end{equation}

that computes the attention coefficients capturing the importance of node $j$'s features to node $i$:
\begin{equation}
e_{ij} = a(\mathbf{W_l}\overrightarrow{h_i}, \mathbf{W_r} \overrightarrow{h_j})
\end{equation}

The most general formulation of self-attention allows every node to attend to all other nodes, which drops all structural information. However, to utilize graph structure in the attention mechanism, we use **masked attention**. In masked attention, we only compute attention coefficients $e_{ij}$ for nodes $j \in \mathcal{N}_i$ where $\mathcal{N}_i$ is some neighborhood of node $i$ in the graph.

To easily compare coefficients across different nodes, we normalize the coefficients across $j$ using a softmax function:
\begin{equation}
\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}
\end{equation}
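The masked softmax only normalizes over the edges that share a central node $i$. The plain-PyTorch sketch below, similar in spirit to the `torch_geometric.utils.softmax` used in the GAT code, groups per-edge scores $e_{ij}$ by `index` (the central node of each edge). `neighborhood_softmax` is a hypothetical helper; subtracting each node's maximum score is a standard numerical-stability step that does not change the result.

```python
import torch


def neighborhood_softmax(e, index, num_nodes):
    """Hypothetical helper: softmax of edge scores e within each central node's neighborhood."""
    # per-node maximum (for numerical stability); start from -inf so any score wins
    node_max = torch.full((num_nodes,), float("-inf")).scatter_reduce(0, index, e, reduce="amax")
    exp = (e - node_max[index]).exp()
    # per-node normalizer: sum of exp over the edges pointing at that node
    denom = torch.zeros(num_nodes).index_add_(0, index, exp)
    return exp / denom[index]


e = torch.tensor([1.0, 2.0, 3.0, 0.5])   # scores e_ij, one per edge
index = torch.tensor([0, 0, 0, 1])       # central node i of each edge
alpha = neighborhood_softmax(e, index, num_nodes=2)
# alpha[:3] sums to 1 (node 0's neighbors); alpha[3] == 1 (node 1's only neighbor)
```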

For this problem, our attention mechanism $a$ will be a single-layer feedforward neural network parametrized by weight vectors $\overrightarrow{a_l} \in \mathbb{R}^{F'}$ and $\overrightarrow{a_r} \in \mathbb{R}^{F'}$, followed by a LeakyReLU nonlinearity (with negative input slope 0.2). Let $\cdot^T$ represent transposition and $||$ represent concatenation. The coefficients computed by our attention mechanism may be expressed as:

\begin{equation}
\alpha_{ij} = \frac{\exp\Big(\text{LeakyReLU}\Big(\overrightarrow{a_l}^T \mathbf{W_l} \overrightarrow{h_i} + \overrightarrow{a_r}^T\mathbf{W_r}\overrightarrow{h_j}\Big)\Big)}{\sum_{k\in \mathcal{N}_i} \exp\Big(\text{LeakyReLU}\Big(\overrightarrow{a_l}^T \mathbf{W_l} \overrightarrow{h_i} + \overrightarrow{a_r}^T\mathbf{W_r}\overrightarrow{h_k}\Big)\Big)}
\end{equation}

In the implementation below, we denote `alpha_l` = $\alpha_l = [\dots, \overrightarrow{a_l}^T \mathbf{W_l} \overrightarrow{h_i}, \dots] \in \mathbb{R}^N$ and `alpha_r` = $\alpha_r = [\dots, \overrightarrow{a_r}^T \mathbf{W_r} \overrightarrow{h_j}, \dots] \in \mathbb{R}^N$, i.e. one scalar per node (computed separately for each head).
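The per-node terms `alpha_l` and `alpha_r` let the attention score of every edge be assembled by indexing instead of recomputing a dot product per edge. The sketch below uses random tensors and a single shared $\mathbf{W}$ (matching $\mathbf{W_l} = \mathbf{W_r}$); it mirrors the shapes in `GAT.forward`, with `H` heads of size `C`, and the test checks one edge against the formula. All sizes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, F_in, H, C = 4, 5, 2, 3                 # nodes, input dim, heads, per-head dim F'
x = torch.randn(N, F_in)
W = torch.randn(F_in, H * C)               # shared W, one C-dim slice per head
a_l = torch.randn(H, C)                    # attention vector for the central node
a_r = torch.randn(H, C)                    # attention vector for the neighbor

Wh = (x @ W).view(N, H, C)                 # [N, H, C]
alpha_l = (Wh * a_l).sum(-1)               # [N, H]: a_l^T W h_i for every node
alpha_r = (Wh * a_r).sum(-1)               # [N, H]: a_r^T W h_j for every node

edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]])  # row 0: neighbor j, row 1: central i
j, i = edge_index
e = F.leaky_relu(alpha_l[i] + alpha_r[j], negative_slope=0.2)  # [E, H] scores e_ij
```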


At every layer of GAT, after the attention coefficients are computed for that layer, the aggregation function can be computed by a weighted sum of neighborhood messages, where weights are specified by $\alpha_{ij}$.

Now, we use the normalized attention coefficients to compute a linear combination of the features corresponding to them. These aggregated features will serve as the final output features for every node.

\begin{equation}
h_i' = \sum_{j \in \mathcal{N}_i} \alpha_{ij} \mathbf{W_r} \overrightarrow{h_j}.
\end{equation} 

### Multi-Head Attention
To stabilize the learning process of self-attention, we use multi-head attention. To do this, we use $K$ independent attention mechanisms, or *heads*, that each compute output features as in the equations above. Then, we concatenate these output feature representations:

\begin{equation}
    \overrightarrow{h_i}' = ||_{k=1}^K \Big(\sum_{j \in \mathcal{N}_i} \alpha_{ij}^{(k)} \mathbf{W_r}^{(k)} \overrightarrow{h_j}\Big)
\end{equation}

where $||$ is concatenation, $\alpha_{ij}^{(k)}$ are the normalized attention coefficients computed by the $k$-th attention mechanism $(a^k)$, and $\mathbf{W_r}^{(k)}$ is the corresponding input linear transformation's weight matrix. Note that in this setting, each $\overrightarrow{h_i}' \in \mathbb{R}^{KF'}$.
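Concatenating the $K$ heads is just a reshape when the per-head outputs are stored as an `[N, K, F']` tensor, which is what `out.view(-1, H * C)` does at the end of `GAT.forward`. A quick check with random values (sizes are arbitrary):

```python
import torch

N, K, F_out = 3, 2, 4
per_head = torch.randn(N, K, F_out)        # output of each head before concatenation
concat = per_head.view(N, K * F_out)       # ||_{k=1}^K as a reshape
explicit = torch.cat([per_head[:, k, :] for k in range(K)], dim=-1)
# both are [N, K * F'] = [3, 8] and identical
```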

In [None]:
class GAT(MessagePassing):

    def __init__(self, in_channels, out_channels, heads = 2,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        # W_l and W_r share a single linear transformation (W in the text above)
        self.lin_l = nn.Linear(in_channels, self.heads * out_channels)
        self.lin_r = self.lin_l

        self.att_l = Parameter(torch.randn(self.heads, out_channels), requires_grad=True)
        self.att_r = Parameter(torch.randn(self.heads, out_channels), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size = None):
        """GAT layer forward pass

        Parameters
        ----------
        x : torch.FloatTensor
            node embedding matrix (N, C)
        
        edge_index : torch.LongTensor/torch_sparse.SparseTensor
            defines the underlying graph connectivity/message passing flow
            shape: (2, E)

        size (optional) : tuple
            The size (N, M) of the assignment matrix in case edge_index is a LongTensor.
            This argument is ignored in case edge_index is a torch_sparse.SparseTensor

        Returns
        -------
        torch.FloatTensor
            node embeddings of nodes in the batch

        """
        
        H, C = self.heads, self.out_channels

        x_i_embed = self.lin_l(x).view(-1, H, C)  # central (target) node representation  (N, H, C)
        x_j_embed = self.lin_r(x).view(-1, H, C)  # neighbor (source) node representation  (N, H, C)
        alpha_l = (x_i_embed * self.att_l).sum(-1)  # (N, H): a_l^T W_l h_i for each node and head
        alpha_r = (x_j_embed * self.att_r).sum(-1)  # (N, H): a_r^T W_r h_j for each node and head
        # PyG reads each tuple as (source/neighbor j, target/central i), so alpha_r
        # becomes alpha_j and alpha_l becomes alpha_i inside message()
        out = self.propagate(edge_index, x=(x_j_embed, x_i_embed), alpha=(alpha_r, alpha_l), size=size)
        out = out.view(-1, H * C)  # concatenate heads

        return out


    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """ GAT message

        Parameters
        ----------
        x_j : torch.FloatTensor
            x_i's neighbors' hidden representation (E, H, C)  (C- hidden dim)

        alpha_j : torch.FloatTensor
            pre-message representation of node neighbor j (before applying leaky relu and softmax)
            shape: (E, H)

        alpha_i : torch.FloatTensor
            pre-message representation message of node i (before applying leaky relu and softmax)
            shape: (E, H)

        index : torch.LongTensor
            central (target) node i of each edge; used by softmax to
            normalize over each node's neighbors
            shape: (E,)

        ptr : torch.LongTensor (optional)
            If given, computes the softmax based on sorted inputs in CSR representation.

        size_i : int
            number of central (target) nodes, passed to softmax as num_nodes
        
        Returns
        -------
        torch.FloatTensor
            message of node x_i before aggregation

        """
        e_ij = F.leaky_relu(alpha_i + alpha_j, negative_slope=self.negative_slope)  # (E, H)
        # softmax normalized over each central node's neighborhood;
        # ptr (CSR pointers) is used when PyG provides it, otherwise index
        alpha_ij = softmax(e_ij, index, ptr, size_i)  # (E, H)

        # apply dropout on attention weights
        alpha_ij = F.dropout(alpha_ij, p=self.dropout, training=self.training)  # (E, H)
        out = x_j * alpha_ij.unsqueeze(-1)  # use unsqueeze for broadcasting  (E, H, C)

        return out


    def aggregate(self, inputs, index, dim_size = None):
        """ Aggregate messages from all neighbors based on index

        Parameters
        ----------
        inputs : torch.FloatTensor
            Inputs to aggregate, i.e. message representations from neighbors

        index : torch.LongTensor
             The indices of elements to scatter (aggregate by index)

        dim_size (optional): int
            Output with size dim_size at dimension dim

        Returns
        -------
        torch.FloatTensor
            Final aggregated representation of a message
        """
        node_dim = self.node_dim
        out = torch_scatter.scatter(inputs, index, dim=node_dim, dim_size=dim_size, reduce="sum")
    
        return out