# Graph Neural Networks (GNNs)
## Part I: Message Passing Neural Networks (MPNNs)
We are going to implement a few MPNNs for molecular property prediction. It's recommended that you're familiar with the recent lectures on GNNs.

# Packages for GNNs
There are two very popular packages for GNNs that use PyTorch as a backend:
1. [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html).
2. [Deep Graph Library](https://www.dgl.ai/pages/start.html) (along with [dgl-lifesci](https://lifesci.dgl.ai/install/index.html)).

The former is more stable; the latter has a convenient extension, [dgl-lifesci](https://lifesci.dgl.ai/generated/dgllife.utils.CanonicalAtomFeaturizer.html), for molecular data and is generally much more user-friendly. For convenience, we are going to use all three packages (PyTorch Geometric, DGL, and dgl-lifesci), so please install appropriate versions of them (I recommend installing with pip). If you have issues installing rdkit (required by dgl-lifesci), you can install it with pip (`pip install rdkit`).

## Installation commands
The following commands should work for the `ml` conda env from the `01_zasady_wstep.ipynb`.

```
conda install -c dglteam/label/cu117 dgl
conda install -c conda-forge dgllife
conda install pyg -c pyg

conda install libboost=1.73.0 boost=1.73.0 boost-cpp=1.73.0
pip install rdkit==2022.3.5

conda install -c conda-forge torchmetrics
conda install -c conda-forge wandb
```

# Molecular graphs
(Copied from [mldd23 repository](https://github.com/gmum/mldd23/blob/main/labs/L3-graph-neural-networks/laboratory.ipynb))
In mathematics, a graph is an object that consists of a set of vertices (nodes) connected with edges, i.e. $\mathcal{G} = (V, E)$, where $V = \{ v_i: i \in \{1, 2, \dots, N \} \}$ and $E \subseteq \{ (v_i, v_j):\, v_i,v_j \in V \}$.

Molecular graphs are a special class of graphs where, besides nodes (denoting atoms) and edges (denoting chemical bonds), we have additional information about atom types and sometimes also bond types. We can assume that we have an additional set of node/atom features encoded as a matrix $X$, where $X_{ij}$ is the $j$-th feature of the $i$-th atom. As atomic features, we can use one-hot encoded atom symbols (a vector containing zeros at all positions except the one that corresponds to the atom symbol), the number of implicit hydrogens bonded to the atom, or the number of heavy neighbors (atoms other than hydrogen bonded to the given atom).
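To make the feature matrix $X$ concrete, here is a minimal sketch of a toy atom featurizer. The symbol vocabulary `SYMBOLS` and the helper `atom_features` are hypothetical and much simpler than dgllife's `CanonicalAtomFeaturizer`; hydrogens are assumed to be implicit, so only bonds between heavy atoms are listed.

```python
import torch

# Hypothetical vocabulary of atom symbols (an assumption for this example only).
SYMBOLS = ["C", "N", "O"]


def atom_features(symbols, bonds):
    """Toy featurizer: one-hot atom symbol followed by the number of heavy neighbors.

    symbols: list of atom symbols, one per heavy atom
    bonds:   list of (i, j) index pairs of bonded heavy atoms
    Returns X with shape [N, F], where F = len(SYMBOLS) + 1.
    """
    degree = [0] * len(symbols)
    for i, j in bonds:  # each bond adds one heavy neighbor to both atoms
        degree[i] += 1
        degree[j] += 1
    rows = []
    for symbol, deg in zip(symbols, degree):
        one_hot = [1.0 if symbol == s else 0.0 for s in SYMBOLS]
        rows.append(one_hot + [float(deg)])
    return torch.tensor(rows)


# Heavy atoms of ethanol (CCO): C-C-O
X = atom_features(["C", "C", "O"], [(0, 1), (1, 2)])
# rows: C with 1 heavy neighbor, C with 2, O with 1
```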

Edges/bonds can be encoded in two different ways. One method is to use an adjacency matrix $A$, where $A_{ij}=1$ if nodes/atoms $v_i$ and $v_j$ are connected ($A_{ij}=0$ otherwise). For sparse graphs (and molecular graphs are sparse), a more memory-efficient encoding is a list of pairs of connected atoms (a list of index pairs). This latter encoding is used by the PyTorch Geometric library.

In practice, a molecular graph can be described by two matrices: $X \in \mathbb{R}^{N \times F}$ and $E \in \{0, 1,\dots,N-1\}^{2 \times M}$, where $N$ is the number of atoms, $F$ is the number of atomic features, and $M$ is the number of (directed) edges.
<img src="resources/mol_graph.png" height="500" />
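Both edge encodings can be converted into each other with a few tensor operations. The sketch below uses plain PyTorch on a hypothetical three-atom chain (the heavy atoms of ethanol, C-C-O); note that the edge list has one column per *directed* edge, so each bond appears twice.

```python
import torch

# Adjacency matrix of the heavy-atom chain C-C-O (N = 3 atoms, 2 bonds).
A = torch.tensor([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]])

# Dense -> sparse: every nonzero A[i, j] becomes one directed edge (i, j),
# so each bond appears twice and E has one column per directed edge.
E = A.nonzero().t()
# E == [[0, 1, 1, 2],
#       [1, 0, 2, 1]]

# Sparse -> dense: put ones back at the listed index pairs.
A_back = torch.zeros_like(A)
A_back[E[0], E[1]] = 1
```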

# Dataset

We are going to use the FreeSolv dataset, which contains 642 hydration free energy values for small molecules. The goal is to predict the [hydration free energy](https://en.wikipedia.org/wiki/Hydration_energy) of a given molecule. It's a very commonly used benchmark for molecular property prediction models. It's small, so we can minimize our CO2 footprint and the time spent on training.

Molecules in most chemical datasets are represented as SMILES strings. SMILES is a linearization of the molecular graph; it's pretty convenient and can even be used as input to text-based models. Fortunately, dgllife provides a handy FreeSolv dataset wrapper that 1) transforms each SMILES into a molecular graph, and 2) encodes its nodes and edges with sensible chemical features (like atom types, bond types, etc.), so we don't need to take care of it ourselves.

In [73]:
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer, SMILESToBigraph
from dgllife.data import FreeSolv
import torch
import dgl

node_featurizer = CanonicalAtomFeaturizer()
edge_featurizer = CanonicalBondFeaturizer(self_loop=True)
dataset = FreeSolv(
    smiles_to_graph=SMILESToBigraph(
        node_featurizer=node_featurizer,
        edge_featurizer=edge_featurizer,
        add_self_loop=True, # well... some of the molecules in the dataset contain no edges, so adding the self-loop (edge from node to itself) makes the future MPNN implementations simpler.
    ),
)

Processing dgl graphs from scratch...


## Playground

In [74]:
smiles, graph, label = dataset[0]
smiles, graph, label

('CN(C)C(=O)c1ccc(cc1)OC',
 Graph(num_nodes=13, num_edges=39,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 tensor([-11.0100]))

We see that each dataset item consists of a SMILES string, a graph, and a label. The graph is a [DGLGraph](https://docs.dgl.ai/en/0.8.x/api/python/dgl.DGLGraph.html) object that contains node and edge features. We can access them with the following code:

In [75]:
graph.ndata['h'].shape  # node features

torch.Size([13, 74])

In [76]:
graph.edata['e'].shape  # edge features

torch.Size([39, 13])

In [77]:
start_nodes, end_nodes = graph.edges()  # edges. Note that edges are directed, so we have two edges for each bond. Moreover, we have self-loops, to easily handle molecules with only one atom.
edges = torch.stack([start_nodes, end_nodes], dim=1)
edges

tensor([[12,  0],
        [ 0, 12],
        [ 0,  2],
        [ 2,  0],
        [ 0,  4],
        [ 4,  0],
        [ 4,  7],
        [ 7,  4],
        [ 4,  9],
        [ 9,  4],
        [ 9,  6],
        [ 6,  9],
        [ 6, 10],
        [10,  6],
        [10, 11],
        [11, 10],
        [11,  3],
        [ 3, 11],
        [ 3,  8],
        [ 8,  3],
        [11,  5],
        [ 5, 11],
        [ 5,  1],
        [ 1,  5],
        [ 8,  9],
        [ 9,  8],
        [ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3],
        [ 4,  4],
        [ 5,  5],
        [ 6,  6],
        [ 7,  7],
        [ 8,  8],
        [ 9,  9],
        [10, 10],
        [11, 11],
        [12, 12]], dtype=torch.int32)
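The graph sizes printed above can be checked with a bit of arithmetic: each bond is stored as two directed edges, and `add_self_loop=True` adds one edge per atom. The helper below is hypothetical and only illustrates this bookkeeping, e.g. $39 = 2 \cdot 13 + 13$ for `dataset[0]`, which therefore has 13 bonds.

```python
def num_bonds(num_nodes: int, num_edges: int, self_loops: bool = True) -> int:
    """Recover the number of bonds from the size of a featurized molecular graph.

    Every bond is stored as two directed edges, and with self-loops every atom
    contributes one extra edge i -> i.
    """
    loops = num_nodes if self_loops else 0
    return (num_edges - loops) // 2


print(num_bonds(13, 39))  # 13 -> dataset[0]: 13 atoms, 39 edges
print(num_bonds(5, 13))   # 4  -> dataset[1]: 5 atoms, 13 edges
```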

Importantly, if we want to create a batch of graphs, we can simply treat the graphs as... a single graph with many disconnected components. Since messages cannot pass between disconnected components, the graphs in a batch won't influence each other. To make a batch from two graphs, we can simply run:

In [78]:
_, graph_1, _ = dataset[0]
_, graph_2, _ = dataset[1]
collated_graph = dgl.batch([graph_1, graph_2])
graph_1, graph_2, collated_graph

(Graph(num_nodes=13, num_edges=39,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=5, num_edges=13,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=18, num_edges=52,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}))

In [79]:
collated_graph.batch_num_nodes()

tensor([13,  5])

In the collated_graph, the ids corresponding to the nodes of graph_2 are shifted by the size of graph_1:

In [80]:
graph_1.nodes(), graph_2.nodes(), collated_graph.nodes()

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12], dtype=torch.int32),
 tensor([0, 1, 2, 3, 4], dtype=torch.int32),
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17],
        dtype=torch.int32))
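The node-id shift can be reproduced by hand. Below is a minimal plain-PyTorch sketch, assuming each graph is given as a node-feature tensor plus a `[2, num_edges]` edge index; the helper `batch_graphs` is hypothetical and is not how `dgl.batch` is implemented internally.

```python
import torch


def batch_graphs(node_feats, edge_indices):
    """Merge several graphs into one graph with disconnected components.

    node_feats:   list of [n_k, F] node-feature tensors
    edge_indices: list of [2, m_k] long tensors with node ids local to each graph
    Returns the stacked features, the shifted edge index, and the graph id of every node.
    """
    sizes = torch.tensor([feats.size(0) for feats in node_feats])
    offsets = torch.cumsum(sizes, dim=0) - sizes  # id of the first node of each graph
    x = torch.cat(node_feats, dim=0)
    edge_index = torch.cat([e + offset for e, offset in zip(edge_indices, offsets)], dim=1)
    batch = torch.repeat_interleave(torch.arange(len(node_feats)), sizes)
    return x, edge_index, batch


g1_edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # 3-node chain
g2_edges = torch.tensor([[0, 1], [1, 0]])              # 2-node graph with one bond
x, edge_index, batch = batch_graphs([torch.randn(3, 4), torch.randn(2, 4)], [g1_edges, g2_edges])
# the node ids of the second graph are shifted by 3, the size of the first graph
```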

## Split
We are going to make our split slightly harder by using [scaffold](https://hub.knime.com/infocom/extensions/jp.co.infocom.cheminfo.jchem.feature/latest/jp.co.infocom.cheminfo.jchem.bemismurckoclustering.BemisMurckoClusteringNodeFactory) splitting, which puts all molecules that share the same scaffold into the same split. Here, the (Bemis-Murcko) scaffold is what remains of a molecule after removing its side chains: its ring systems and the linkers connecting them.

In [81]:
from dgllife.utils import ScaffoldSplitter

splitter = ScaffoldSplitter()
train, valid, test = splitter.train_val_test_split(dataset)

Start initializing RDKit molecule instances...
Start computing Bemis-Murcko scaffolds.
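The core idea of scaffold splitting fits in a few lines. The sketch below assumes the scaffold of every molecule has already been computed (e.g. with RDKit's `MurckoScaffold` module) and is given as a string key; it uses a simplified greedy assignment (largest scaffold groups go to train first), so dgllife's exact rules may differ. `scaffold_split` is a hypothetical helper.

```python
from collections import defaultdict


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Greedy scaffold split; scaffolds[i] is the (precomputed) scaffold key of molecule i.

    Molecules with the same scaffold always end up in the same split.
    """
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    # Largest groups first; ties broken by the smallest molecule index.
    ordered = sorted(groups.values(), key=lambda g: (len(g), -g[0]), reverse=True)
    n = len(scaffolds)
    train_idx, valid_idx, test_idx = [], [], []
    for group in ordered:
        if len(train_idx) + len(group) <= frac_train * n:
            train_idx += group
        elif len(valid_idx) + len(group) <= frac_valid * n:
            valid_idx += group
        else:
            test_idx += group
    return train_idx, valid_idx, test_idx


# Hypothetical scaffold keys for 10 molecules
scaffolds = ["a", "a", "a", "b", "b", "c", "d", "a", "b", "e"]
train_idx, valid_idx, test_idx = scaffold_split(scaffolds)
# train_idx holds scaffolds a, b, c; valid_idx == [6] (scaffold d); test_idx == [9] (scaffold e)
```

Because whole scaffold groups move together, the test set contains scaffolds the model never saw during training, which is what makes this split harder than a random one.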


# Code

## Trainer

In [82]:
import copy
import numpy as np
from tqdm.autonotebook import tqdm
from dgl.dataloading import GraphDataLoader
from torchmetrics import Metric
from dgl.data import Subset
from torch import nn
from typing import Type
from typing import Dict, Any
from pathlib import Path
from abc import ABC, abstractmethod
from lab.checker import expected_mean_readout, expected_gin_layer_output, expected_sage_layer_output, \
    expected_attention_readout, expected_gine_layer_output, expected_sum_readout, expected_simple_mpnn_output


class LoggerBase(ABC):
    def __init__(self, logdir: str | Path):
        self.logdir = Path(logdir)
        self.logdir.mkdir(parents=True, exist_ok=True)

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        ...

    @abstractmethod
    def close(self):
        ...


class DummyLogger(LoggerBase):  # If you don't want to use any logger, you can use this one
    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        pass

    def close(self):
        pass

    def restart(self):
        pass


class MetricList:
    def __init__(self, metrics: Dict[str, Metric]):
        self.metrics = copy.deepcopy(metrics)

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        for name, metric in self.metrics.items():
            metric.update(preds.detach().cpu(), targets.cpu())

    def compute(self) -> Dict[str, float]:
        metrics = {}
        for name, metric_fn in self.metrics.items():
            metrics[name] = metric_fn.compute().item()
            metric_fn.reset()
        return metrics


class Trainer:
    def __init__(
            self,
            *,
            run_dir: str | Path,
            train_dataset: Subset,
            valid_dataset: Subset,
            train_metrics: Dict[str, Metric],
            valid_metrics: Dict[str, Metric],
            model: nn.Module,
            logger: LoggerBase,
            optimizer_kwargs: Dict[str, Any],
            optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
            n_epochs: int,
            train_batch_size: int = 32,
            valid_batch_size: int = 16,
            device: str = "cuda",
            valid_every_n_epochs: int = 1,
            loss_fn=nn.MSELoss()
    ):
        self.run_dir = Path(run_dir)
        self.train_loader = GraphDataLoader(
            dataset=train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
        )
        self.valid_loader = GraphDataLoader(
            dataset=valid_dataset,
            batch_size=valid_batch_size,
            shuffle=False,  # the order does not affect validation metrics
        )
        self.train_metrics = MetricList(train_metrics)
        self.valid_metrics = MetricList(valid_metrics)
        self.logger = logger
        self.model = model
        self.optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
        self.n_epochs = n_epochs
        self.device = device
        self.valid_every_n_epochs = valid_every_n_epochs
        self.loss_fn = loss_fn
        self.model.to(device)

    @torch.no_grad()
    def validate(self, dataloader: GraphDataLoader, prefix: str) -> Dict[str, float]:
        previous_mode = self.model.training
        self.model.eval()
        losses = []
        for _, graphs, labels in dataloader:
            graphs = graphs.to(self.device)
            labels = labels.to(self.device)
            preds = self.model(graphs)
            loss = self.loss_fn(preds, labels)
            losses.append(loss.item())
            self.valid_metrics.update(preds, labels)
        self.model.train(mode=previous_mode)
        metrics = {"loss": np.mean(losses)} | self.valid_metrics.compute()
        self.logger.log_metrics(metrics=metrics, prefix=prefix)
        return metrics

    def train(self) -> Dict[str, float]:
        self.model.train()
        valid_metrics = {}
        for epoch in tqdm(range(self.n_epochs), total=self.n_epochs):
            for _, graphs, labels in self.train_loader:
                self.optimizer.zero_grad()
                graphs = graphs.to(self.device)
                labels = labels.to(self.device)
                preds = self.model(graphs)
                loss = self.loss_fn(preds, labels)
                loss.backward()
                self.optimizer.step()

                self.train_metrics.update(preds, labels)
                train_metrics = {"loss": loss.item()} | self.train_metrics.compute()
                self.logger.log_metrics(metrics=train_metrics, prefix="train")

            # validate once per epoch (after all training batches), not after every batch
            if epoch % self.valid_every_n_epochs == 0 or epoch == self.n_epochs - 1:
                valid_metrics = self.validate(self.valid_loader, prefix="valid")

        return valid_metrics

    def test(self, dataset: Subset) -> Dict[str, float]:
        dataloader = GraphDataLoader(
            dataset=dataset,
            batch_size=16,
            shuffle=False,
        )
        return self.validate(dataloader, prefix="test")

    def close(self):  # close the logger, not really required for wandb
        self.logger.close()

# Graph Neural Networks (GNNs)
The high-level Graph Neural Network architecture we are going to use looks roughly like this:

<img src="resources/gnn.png" width="1200" />

- The Featurizer takes a molecule and transforms it into a graph with node and edge features (this happens at the dataset level, so we don't need to worry about it).
- In our case, we linearly embed the node and edge features to the hidden size before applying the first MPNN layer; this step is not shown in the diagram.
- Each MPNN layer takes the node (and possibly edge) embeddings together with the graph structure and returns updated node embeddings. Several MPNN layers are applied one after another.
- Then the node embeddings are aggregated by the Readout layer to obtain a graph embedding.
- Finally, the graph embedding is passed to the MLP to obtain the final prediction.
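The same pipeline can be sketched without DGL. The class below is hypothetical and deliberately minimal: a linear embedding, a few sum-aggregation message passing layers, a sum readout, and an MLP. Graphs are assumed to be batched as a `[2, num_edges]` edge index plus a `batch` vector that maps each node to its graph; edge features are ignored.

```python
import torch
from torch import nn


class TinyGNN(nn.Module):
    """Embedding -> n_layers of sum message passing -> sum readout -> MLP."""

    def __init__(self, node_in: int, hidden: int, out: int, n_layers: int = 2):
        super().__init__()
        self.embed = nn.Linear(node_in, hidden)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_layers)])
        self.mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, out))

    def forward(self, x, edge_index, batch, num_graphs):
        h = self.embed(x)      # [num_nodes, hidden]
        src, dst = edge_index  # directed edges src -> dst
        for layer in self.layers:
            # every node sums the embeddings of the nodes sending edges to it
            agg = torch.zeros_like(h).index_add_(0, dst, h[src])
            h = torch.relu(layer(h + agg))  # simple update: own state + messages
        # readout: sum node embeddings per graph -> [num_graphs, hidden]
        g = torch.zeros(num_graphs, h.size(1), dtype=h.dtype).index_add_(0, batch, h)
        return self.mlp(g)     # [num_graphs, out]


torch.manual_seed(0)
x = torch.randn(5, 8)                                          # 3 + 2 nodes
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
batch = torch.tensor([0, 0, 0, 1, 1])                          # graph id of each node
out = TinyGNN(node_in=8, hidden=16, out=1)(x, edge_index, batch, num_graphs=2)  # [2, 1]
```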

In [83]:
class MPNNLayerBase(ABC, nn.Module):
    def __init__(self, hidden_size: int | None = None):
        """
        Attributes:
            hidden_size: the size of node (and edge) embeddings
        """
        super().__init__()
        self.hidden_size = hidden_size  # subclasses calling super().__init__() set it themselves

    @abstractmethod
    def forward(self, 
                node_embeddings: torch.Tensor, 
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        ...


class ReadoutBase(ABC, nn.Module):  # ABC makes @abstractmethod actually enforced
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

    @abstractmethod
    def forward(self,
                node_embeddings: torch.Tensor, 
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape [batch_size, hidden_size]
        """
        ...


class GNN(nn.Module):
    def __init__(self,
                 node_features_size: int,
                 edge_features_size: int,
                 hidden_size: int,
                 output_size: int,
                 mpnn_layer_cls: Type[MPNNLayerBase],
                 mpnn_layer_kwargs: Dict[str, Any],
                 mpnn_n_layers: int,
                 readout_cls: Type[ReadoutBase]):
        """
        Arguments:
            node_features_size: the size of node features
            edge_features_size: the size of edge features
            hidden_size: the size of node (and edge) embeddings
            output_size: the size of the final prediction
            mpnn_layer_cls: the class of MPNN layer
            mpnn_layer_kwargs: the kwargs for the MPNN layer
            mpnn_n_layers: the number of MPNN layers
            readout_cls: the class of Readout layer
        """
        super().__init__()
        self.linear_node = nn.Linear(node_features_size, hidden_size)
        self.linear_edge = nn.Linear(edge_features_size, hidden_size)
        self.mpnn_layers = nn.ModuleList([
            mpnn_layer_cls(hidden_size=hidden_size, **mpnn_layer_kwargs)
            for _ in range(mpnn_n_layers)
        ])
        self.readout = readout_cls(hidden_size=hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            graph: a DGLGraph that contains the graph structure and node/edge features in a sparse format
        Returns:
            predictions: the final predictions
        """
        node_embeddings, edge_embeddings = graph.ndata['h'], graph.edata['e']
        node_embeddings = self.linear_node(node_embeddings)
        edge_embeddings = self.linear_edge(edge_embeddings)  # some models do not use edge features, but for simplicity we embed them anyway instead of using if-clauses
        for layer in self.mpnn_layers:
            node_embeddings = layer(node_embeddings=node_embeddings, edge_embeddings=edge_embeddings, graph=graph)
        graph_embedding = self.readout(node_embeddings, graph)
        predictions = self.mlp(graph_embedding)
        return predictions

## Readout
The readout operation aggregates node embeddings into a graph embedding. There are many different readout operations, but the most popular are sum, mean, attention, and max. We are going to implement the first three of them. Summing over node embeddings seems trivial, but they're stored in a sparse format, meaning that the nodes from all the graphs in a batch are stored in a single tensor of size `[num_nodes_1 + num_nodes_2 + ... + num_nodes_N, hidden_size]`:

In [84]:
batched_graph = dgl.batch([dataset[0][1], dataset[1][1], dataset[2][1]])
linear = nn.Linear(node_featurizer.feat_size(), 16)
node_embeddings = linear(batched_graph.ndata['h'])
node_embeddings.shape, batched_graph.batch_num_nodes()

(torch.Size([23, 16]), tensor([13,  5,  5]))

For simplicity, we will convert the sparse node embeddings to a dense format with padding. Then the shape of the node embeddings will be `[batch_size, max_num_nodes, hidden_size]`. We can use the `to_dense_batch` function from `torch_geometric` for that:

In [85]:
from typing import Tuple
from torch_geometric.utils import to_dense_batch


def to_dense_embeddings(node_embeddings: torch.Tensor, 
                        graph: dgl.DGLGraph, 
                        fill_value: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts sparse node embeddings to dense node embeddings with padding.
    Arguments:
        node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        graph: a batch of graphs
        fill_value: a value to fill the padding with
    Returns:
        node_embeddings: node embeddings in a dense format, i.e. [batch_size, max_num_nodes, hidden_size]
        mask: a mask indicating which nodes are real and which are padding, i.e. [batch_size, max_num_nodes]
    """
    num_nodes = graph.batch_num_nodes() # e.g. [2, 3, 3]
    indices = torch.arange(len(num_nodes), device=num_nodes.device)
    batch = torch.repeat_interleave(indices, num_nodes).long() # e.g. [0, 0, 1, 1, 1, 2, 2, 2]
    return to_dense_batch(node_embeddings, batch,
                          fill_value=fill_value)  # that's the only reason we have torch_geometric in the requirements


def to_sparse_embeddings(node_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Converts dense node embeddings to sparse node embeddings.
    Arguments:
        node_embeddings: node embeddings in a dense format, i.e. [batch_size, max_num_nodes, hidden_size]
        mask: a mask indicating which nodes are real and which are padding, i.e. [batch_size, max_num_nodes]
    Returns:
        node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
    """
    return node_embeddings[mask]
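For intuition, the padding that `to_dense_batch` performs can be written in plain PyTorch. This is a sketch assuming CPU tensors and nodes stored contiguously graph by graph (as in a DGL batch); `dense_batch` is a hypothetical helper, not part of `torch_geometric`.

```python
import torch


def dense_batch(x: torch.Tensor, num_nodes: torch.Tensor, fill_value: float = 0.0):
    """Pad sparse node embeddings into [batch_size, max_num_nodes, hidden_size].

    x:         [total_num_nodes, hidden_size], nodes stored graph by graph
    num_nodes: [batch_size] number of nodes of each graph
    Returns the padded tensor and a boolean mask of real (non-padding) nodes.
    """
    batch_size, max_num_nodes = num_nodes.numel(), int(num_nodes.max())
    graph_id = torch.repeat_interleave(torch.arange(batch_size), num_nodes)  # graph of each node
    first_node = torch.cumsum(num_nodes, dim=0) - num_nodes                  # offset of each graph
    position = torch.arange(x.size(0)) - first_node[graph_id]                # slot inside its graph
    dense = x.new_full((batch_size, max_num_nodes, x.size(1)), fill_value)
    dense[graph_id, position] = x
    mask = torch.zeros(batch_size, max_num_nodes, dtype=torch.bool)
    mask[graph_id, position] = True
    return dense, mask


num_nodes = torch.tensor([2, 3])
x = torch.arange(10, dtype=torch.float32).view(5, 2)
dense, mask = dense_batch(x, num_nodes)
# dense[0, 2] is padding (filled with 0.0); dense[mask] recovers x,
# which is exactly what to_sparse_embeddings relies on.
```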

Now, we can simply convert the node embeddings to a dense format and sum them $x = \sum_i^n x_i$:

In [86]:
class SumReadout(ReadoutBase):
    def forward(self, 
                node_embeddings: torch.Tensor, 
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape [batch_size, hidden_size]
        """
        # We could also use the dgl.sum_nodes function, but let's assume it's forbidden in this notebook ;)
        node_embeddings, _ = to_dense_embeddings(node_embeddings, graph)
        return node_embeddings.sum(dim=1)

In [87]:
def test_readout(readout_cls: Type[ReadoutBase], expected_output: torch.Tensor):
    torch.manual_seed(0)
    graph = dgl.batch([dataset[0][1], dataset[1][1], dataset[2][1]])
    linear = nn.Linear(node_featurizer.feat_size(), 16)
    node_embeddings = linear(graph.ndata['h'])
    readout = readout_cls(hidden_size=16)
    result = readout(node_embeddings, graph)
    assert torch.allclose(result, expected_output, atol=1e-3)

In [88]:
test_readout(SumReadout, expected_sum_readout)
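As a cross-check of `SumReadout`, the same per-graph sum can be computed directly on the sparse tensor, without padding, by adding each node row into the row of its graph with `Tensor.index_add_`. This is an alternative sketch (CPU tensors assumed, `sum_readout_sparse` is a hypothetical name), not a replacement for the dense implementation.

```python
import torch


def sum_readout_sparse(node_embeddings: torch.Tensor, num_nodes: torch.Tensor) -> torch.Tensor:
    """Per-graph sum: [total_num_nodes, hidden_size] -> [batch_size, hidden_size]."""
    batch_size = num_nodes.numel()
    graph_id = torch.repeat_interleave(torch.arange(batch_size), num_nodes)
    out = torch.zeros(batch_size, node_embeddings.size(1), dtype=node_embeddings.dtype)
    return out.index_add_(0, graph_id, node_embeddings)  # out[graph_id[k]] += node_embeddings[k]


x = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.]])
print(sum_readout_sparse(x, torch.tensor([2, 3])).tolist())  # [[4.0, 6.0], [21.0, 24.0]]
```

With a DGL batch, the node counts would come from `graph.batch_num_nodes()`.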

### Task 1. Implement mean readout (1 point).
Implement the mean readout given by the formula $x = \frac{1}{n}\sum_i^n x_i$:

In [89]:
class MeanReadout(ReadoutBase):
    def forward(self, 
                node_embeddings: torch.Tensor, 
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape [batch_size, hidden_size]
        """
        # Don't use any dgl functions here
        pass
    
test_readout(MeanReadout, expected_mean_readout)

### Task 2. Implement attention readout (2 points).
Implement the attention readout given by the formula $x = \sum_i^n \frac{\exp(\mathrm{score}_i)}{\sum_j^n \exp(\mathrm{score}_j)}x_i$, where $\mathrm{score}_i=\mathrm{score\_mlp}(x_i)$:

In [90]:
class AttentionReadout(ReadoutBase):
    def __init__(self, hidden_size: int):
        super().__init__(hidden_size)
        self.score_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, 
                node_embeddings: torch.Tensor, 
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape [batch_size, hidden_size]
        """
        pass
    
test_readout(AttentionReadout, expected_attention_readout)

## Message Passing Neural Networks (MPNNs)
Message passing is given by the formula:
$$
x'_i=\rho(x_i, \square_{j\in N(i)} \psi(x_j, x_i, e_{ji})),
$$ 
where $\psi$ is a learnable message function, $\rho$ is a learnable update function, and $\square$ is a permutation-invariant aggregation function (e.g. sum, mean, or max). $N(i)$ denotes the set of neighbors of node $i$. Note that in our dataset we added a self-loop to every node, so $N(i)$ also contains $i$; we won't treat this case separately.
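A minimal sketch of one generic message passing step, assuming sum as the aggregation $\square$ and single linear layers for $\psi$ and $\rho$ (both choices are illustrative). Edges are stored as a `[2, num_edges]` index whose first row holds the sender $j$ and second row the receiver $i$; `GenericMessagePassing` is a hypothetical name.

```python
import torch
from torch import nn


class GenericMessagePassing(nn.Module):
    """x'_i = rho(x_i, sum_{j in N(i)} psi(x_j, x_i, e_ji)), with sum as the aggregation."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.psi = nn.Linear(3 * hidden_size, hidden_size)  # message from [x_j, x_i, e_ji]
        self.rho = nn.Linear(2 * hidden_size, hidden_size)  # update from [x_i, aggregated]

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # edge j -> i is stored as (src=j, dst=i)
        messages = self.psi(torch.cat([x[src], x[dst], e], dim=-1))    # [num_edges, hidden]
        aggregated = torch.zeros_like(x).index_add_(0, dst, messages)  # sum over N(i)
        return self.rho(torch.cat([x, aggregated], dim=-1))            # [num_nodes, hidden]


torch.manual_seed(0)
x = torch.randn(4, 8)                                    # node 3 has no neighbors
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # chain 0-1-2, both directions
e = torch.randn(edge_index.size(1), 8)                   # one feature vector per directed edge
layer = GenericMessagePassing(hidden_size=8)
out = layer(x, edge_index, e)                            # shape [4, 8]
```

A node without neighbors simply receives a zero aggregated message, so its update depends on $x_i$ alone.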

### Simple MPNN
For instance, we can define a very simple MPNN layer by the following formula:
$$
x'_i=W_1x_i + W_2\sum_{j\in N(i)} W_3x_j,
$$
where $W_i$ are linear layers with an implicit bias term (we keep the bias implicit in every formula in this notebook). Let us implement this simple MPNN:

In [91]:
class SimpleMPNNLayer(MPNNLayerBase):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        # graph is bi-directed, so we can freely swap the "start" and "end" meanings
        start_nodes, end_nodes = graph.edges(order='srcdst') # using this `order` value sorts the `start_nodes`
        messages = self.linear_3(node_embeddings[end_nodes]) # W_3x_j
        message_dense, _ = to_dense_batch(messages, start_nodes.long(), fill_value=0.0)  # to make life easier, group the messages by node i into a dense [num_nodes, max_degree, hidden_size] tensor
        aggregated_message = message_dense.sum(dim=1) # \sum_{j\in N(i)} W_3x_j
        aggregated_message = self.linear_2(aggregated_message) # W_2\sum_{j\in N(i)} W_3x_j
        node_embeddings = self.linear_1(node_embeddings) + aggregated_message # W_1x_i + W_2\sum_{j\in N(i)} W_3x_j
        return node_embeddings

In [92]:
def test_mpnn_layer(mpnn_layer_cls: Type[MPNNLayerBase], expected_output: torch.Tensor):
    torch.manual_seed(0)
    graph = dgl.batch([dataset[0][1], dataset[1][1]])
    linear_nodes = nn.Linear(node_featurizer.feat_size(), 4)
    linear_edges = nn.Linear(edge_featurizer.feat_size(), 4)
    node_embeddings = linear_nodes(graph.ndata['h'])
    edge_embeddings = linear_edges(graph.edata['e'])
    layer = mpnn_layer_cls(hidden_size=4)
    result = layer(node_embeddings, edge_embeddings, graph)
    assert torch.allclose(result, expected_output, atol=1e-3)

In [93]:
test_mpnn_layer(SimpleMPNNLayer, expected_simple_mpnn_output)

### Task 3. Implement GraphSAGE layer (2 points).
Implement a GraphSAGE layer given by the following formula:
$$
x'_i=W_1x_i + W_2\frac{1}{\deg(i)}\sum_{j\in N(i)} x_j,
$$
where $\deg(i) = |N(i)|$ is the number of neighbors of node $i$.

In [94]:
class SAGELayer(MPNNLayerBase):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        pass

test_mpnn_layer(SAGELayer, expected_sage_layer_output)

### Task 4. Implement GIN layer (2 points).
Implement a GIN layer given by the following formula:
$$
x'_i=mlp((1 + \epsilon)x_i + \sum_{j\in N(i)} x_j).
$$

In [95]:
class GINLayer(MPNNLayerBase):
    def __init__(self, hidden_size: int, eps: float = 0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        pass

test_mpnn_layer(GINLayer, expected_gin_layer_output)

### Task 5. Implement GINE layer (2 points).
Implement a GINE layer given by the following formula:
$$
x'_i=mlp((1 + \epsilon)x_i + \sum_{j\in N(i)} ReLU(x_j + e_{ji})).
$$

In [96]:
class GINELayer(MPNNLayerBase):
    def __init__(self, hidden_size: int, eps: float = 0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.relu = nn.ReLU()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        start_nodes, end_nodes, edge_ids = graph.edges(order='srcdst', form='all')
        pass

test_mpnn_layer(GINELayer, expected_gine_layer_output)

# Experiments

## Logger
We are going to use [wandb](https://wandb.ai/site) for logging. It's a very convenient tool for logging and visualizing the training process. It's free for academic use, so you can create an account and use it for your projects. If you don't want to use wandb, you can use any other online logger (like [comet.ml](https://www.comet.ml/site/)), but you need to implement the appropriate LoggerBase subclass on your own. To set up and use wandb, do the following:
1. [Setup the wandb](https://docs.wandb.ai/quickstart) (or any other online logger).
2. Give your supervisor access to your project (ask him/her for the username).
3. Use the logger for all your trainings and provide the links to the final runs.
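If you prefer not to use an online service while debugging, a local logger with the same interface is easy to write. The sketch below is hypothetical: it mirrors the `log_metrics`/`close` methods of `LoggerBase` (in the notebook you would subclass `LoggerBase` instead) and appends one JSON line per call. It does not replace the wandb links the task asks for.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


class JsonlLogger:
    """Hypothetical local logger with the same methods as LoggerBase.

    In the notebook, declare it as `class JsonlLogger(LoggerBase)` and call
    `super().__init__(logdir)` instead of creating the directory here.
    """

    def __init__(self, logdir: Union[str, Path]):
        self.logdir = Path(logdir)
        self.logdir.mkdir(parents=True, exist_ok=True)
        self.file = open(self.logdir / "metrics.jsonl", "a")

    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        # Same key convention as WandbLogger: "<prefix>/<metric name>"
        record = {f"{prefix}/{k}": float(v) for k, v in metrics.items()}
        self.file.write(json.dumps(record) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()


logdir = tempfile.mkdtemp()
logger = JsonlLogger(logdir)
logger.log_metrics({"loss": 0.5, "mae": 1.25}, prefix="valid")
logger.close()
lines = (Path(logdir) / "metrics.jsonl").read_text().splitlines()
# lines == ['{"valid/loss": 0.5, "valid/mae": 1.25}']
```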

In [97]:
class WandbLogger(LoggerBase):
    def __init__(
            self, logdir: str | Path, project_name: str, experiment_name: str, **kwargs: Any
    ):
        super().__init__(logdir)
        import wandb
        self.project_name = project_name
        self.experiment_name = experiment_name
        self.kwargs = kwargs
        self.run = wandb.init(
            dir=self.logdir,
            project=self.project_name,
            name=self.experiment_name,
            **self.kwargs,
        )

    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
        self.run.log(metrics)

    def close(self):
        self.run.finish()
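For quick offline debugging runs, a logger with the same `log_metrics(metrics, prefix)` / `close()` interface can simply append JSON lines to a local file. This is a hypothetical sketch: it does not subclass `LoggerBase` (whose full interface isn't shown here), so in the notebook you would inherit from it and call `super().__init__(logdir)`. It does not replace an online logger for the final runs.

```python
import json
from pathlib import Path
from typing import Any, Dict, Union


class JsonlLogger:
    """Hypothetical offline logger mirroring WandbLogger's methods.

    Appends one JSON object per log_metrics call to <logdir>/<experiment_name>.jsonl.
    """

    def __init__(self, logdir: Union[str, Path], experiment_name: str):
        self.logdir = Path(logdir)
        self.logdir.mkdir(parents=True, exist_ok=True)
        self.path = self.logdir / f"{experiment_name}.jsonl"
        self._file = self.path.open("a", encoding="utf-8")

    def log_metrics(self, metrics: Dict[str, Any], prefix: str) -> None:
        # Same key convention as WandbLogger: "<prefix>/<metric name>";
        # float() also converts 0-d tensors returned by torchmetrics
        record = {f"{prefix}/{k}": float(v) for k, v in metrics.items()}
        self._file.write(json.dumps(record) + "\n")
        self._file.flush()

    def close(self) -> None:
        self._file.close()
```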

## Task 6. Train GraphSAGE (2 points).
1. Tune the hyperparameters of a GNN that uses `SAGELayer` as its MPNN layer so that the validation MAE is at most 2.0. You may modify the GNN/MPNN architecture, e.g. to add regularization such as dropout or batch normalization, but don't change the validation batch size. A validation MAE in (2.0, 2.5] earns 1 point.
2. Report the obtained MAE on the validation and test sets (only the validation MAE needs to be at most 2.0).
3. Provide the link to the final run: [your link]

In [98]:
### Example code for training. You can modify it for easier grid-searching.

In [99]:
from datetime import datetime


def get_time_stamp() -> str:
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

In [None]:
from torchmetrics import MeanAbsoluteError as MAE
from torchmetrics import MeanSquaredError as MSE
from torchmetrics import PearsonCorrCoef as PCC

metrics = {
    "mae": MAE(),
    "mse": MSE(),
    "pcc": PCC(),
}

model = GNN(
    node_features_size=node_featurizer.feat_size(),
    edge_features_size=edge_featurizer.feat_size(),
    hidden_size=256,
    output_size=1,
    mpnn_layer_cls=SAGELayer,
    mpnn_n_layers=6,
    readout_cls=MeanReadout,
    mpnn_layer_kwargs={}
)

trainer = Trainer(
    run_dir="experiments",
    train_dataset=train,
    valid_dataset=valid,
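    train_metrics=metrics,
    # torchmetrics objects are stateful; give validation its own copies
    valid_metrics={name: metric.clone() for name, metric in metrics.items()},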
    train_batch_size=32,
    model=model,
    logger=WandbLogger(
        logdir="runs/mpnn",
        project_name="mldd23",
        experiment_name=f"sage_{get_time_stamp()}",
    ),
    optimizer_kwargs={"lr": 1e-4},
    n_epochs=50,
    device="cpu",
    valid_every_n_epochs=1,
)

valid_metrics = trainer.train()
test_metrics = trainer.test(test)
trainer.close()
print(f"Validation metrics: {valid_metrics}")
print(f"Test metrics: {test_metrics}")
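Since the training cell is meant to be adapted for grid search, here is a minimal sketch of enumerating a hyperparameter grid with `itertools.product`. The search space is a hypothetical example, not a recommended setting. Each resulting dict could be used to build the `GNN` and `Trainer` above (e.g. `hidden_size=cfg["hidden_size"]`, `mpnn_n_layers=cfg["mpnn_n_layers"]`, `optimizer_kwargs={"lr": cfg["lr"]}`) and to give each wandb run a descriptive `experiment_name`.

```python
import itertools
from typing import Any, Dict, Iterator, List


def grid(param_grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield every combination of the given hyperparameter values as a dict."""
    keys = list(param_grid)
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


# Hypothetical search space; the keys mirror arguments used in the training cell
search_space = {
    "hidden_size": [128, 256],
    "mpnn_n_layers": [4, 6],
    "lr": [1e-3, 1e-4],
}

configs = list(grid(search_space))
print(len(configs))  # 8
print(configs[0])    # {'hidden_size': 128, 'mpnn_n_layers': 4, 'lr': 0.001}
```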

## Task 7. Train GIN (2 points).
1. Tune the hyperparameters of a GNN that uses `GINLayer` as its MPNN layer so that the validation MAE is at most 2.0. You may modify the GNN/MPNN architecture, e.g. to add regularization such as dropout or batch normalization, but don't change the validation batch size. A validation MAE in (2.0, 2.5] earns 1 point.
2. Report the obtained MAE on the validation and test sets (only the validation MAE needs to be at most 2.0).
3. Provide the link to the final run: [your link]

## Task 8. Train GINE (2 points).
1. Tune the hyperparameters of a GNN that uses `GINELayer` as its MPNN layer so that the validation MAE is at most 2.0. You may modify the GNN/MPNN architecture, e.g. to add regularization such as dropout or batch normalization, but don't change the validation batch size. A validation MAE in (2.0, 2.5] earns 1 point.
2. Report the obtained MAE on the validation and test sets (only the validation MAE needs to be at most 2.0).
3. Provide the link to the final run: [your link]

# Code optimization
Some pieces of the code above were written suboptimally. Your task is to optimize them slightly.

## Task 9. Optimize SumReadout (1 point).
`SumReadout` was implemented with the `to_dense_embeddings` function, which performs unnecessary memory allocations and computations. Your task is to rewrite the method using only plain PyTorch operations. Hint: `torch.index_add`.
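To illustrate the hint, the sketch below computes a per-graph sum with `Tensor.index_add` directly on the sparse `[total_num_nodes, hidden_size]` layout. It assumes nodes are stored graph by graph (as in a batched `DGLGraph`) and takes the per-graph node counts as input; in DGL these should be available from `graph.batch_num_nodes()`. The function name `segment_sum` is hypothetical.

```python
import torch


def segment_sum(node_embeddings: torch.Tensor, batch_num_nodes: torch.Tensor) -> torch.Tensor:
    """Sum node embeddings per graph: [total_num_nodes, hidden] -> [batch_size, hidden].

    batch_num_nodes[g] is the number of nodes of graph g; nodes are assumed to be
    stored graph by graph.
    """
    batch_size = batch_num_nodes.shape[0]
    # Graph index of every node, e.g. counts [2, 3] -> ids [0, 0, 1, 1, 1]
    graph_ids = torch.repeat_interleave(
        torch.arange(batch_size, device=node_embeddings.device), batch_num_nodes
    )
    out = node_embeddings.new_zeros(batch_size, node_embeddings.shape[1])
    # out[graph_ids[i]] += node_embeddings[i] for every node i
    return out.index_add(0, graph_ids, node_embeddings)


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
print(segment_sum(x, torch.tensor([2, 3])))  # expected values: [[4, 6], [21, 24]]
```

A graph with zero nodes simply keeps a zero row, which matches a sum over an empty set.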

In [None]:
class OptimizedSumReadout(ReadoutBase):
    def forward(self, 
                node_embeddings: torch.Tensor, 
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape [batch_size, hidden_size]
        """
        pass

test_readout(OptimizedSumReadout, expected_sum_readout)

## Task 10. Optimize MeanReadout (1 point).
Your task is to rewrite `MeanReadout` using only plain PyTorch operations.
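A mean readout can reuse the same scatter-sum pattern and then divide by the number of nodes in each graph. This minimal sketch assumes, as above, that nodes are stored graph by graph and that per-graph node counts (e.g. from DGL's `graph.batch_num_nodes()`) are given; `segment_mean` is a hypothetical name, and clamping the counts to 1 is a design choice that returns zeros for empty graphs instead of NaN.

```python
import torch


def segment_mean(node_embeddings: torch.Tensor, batch_num_nodes: torch.Tensor) -> torch.Tensor:
    """Average node embeddings per graph using only torch ops."""
    batch_size = batch_num_nodes.shape[0]
    # Graph index of every node
    graph_ids = torch.repeat_interleave(
        torch.arange(batch_size, device=node_embeddings.device), batch_num_nodes
    )
    # Per-graph sums, shape [batch_size, hidden]
    sums = node_embeddings.new_zeros(batch_size, node_embeddings.shape[1]).index_add(
        0, graph_ids, node_embeddings
    )
    # Divide by node counts; clamping avoids 0/0 for graphs without nodes
    counts = batch_num_nodes.clamp(min=1).to(node_embeddings.dtype).unsqueeze(1)
    return sums / counts


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
print(segment_mean(x, torch.tensor([2, 3])))  # expected values: [[2, 3], [7, 8]]
```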

In [None]:
class OptimizedMeanReadout(ReadoutBase):
    def forward(self, 
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape [batch_size, hidden_size]
        """
        pass


test_readout(OptimizedMeanReadout, expected_mean_readout)

## Task 11. Optimize SimpleMPNNLayer (1 point).
We can make our implementation of `SimpleMPNNLayer` (and basically any other MPNN layer) slightly faster by:
- reducing the cost of embedding the messages (in `SimpleMPNNLayer`, the application of `self.linear_3`) from $O(m)$ to $O(n)$, where $m$ is the number of edges in the graph and $n$ is the number of nodes;
- removing the rather expensive `to_dense_batch` call.

Your task is to apply the above optimizations.
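The first optimization relies on the fact that `nn.Linear` acts row by row, so gathering rows and applying the layer commute: `linear(h[src])` equals `linear(h)[src]`, but the right-hand side runs the layer on $n$ node rows instead of $m$ edge rows. This only holds when the transformed quantity depends on the source node alone. The sketch below does not reproduce `SimpleMPNNLayer` (its exact update isn't shown here); it only demonstrates the identity on a toy graph and a sum aggregation via `index_add` that never builds a dense `[batch, max_nodes, hidden]` tensor.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_nodes, hidden = 5, 8
h = torch.randn(num_nodes, hidden)
src = torch.tensor([0, 1, 1, 2, 3, 4, 4])  # source node of each edge
dst = torch.tensor([1, 0, 2, 1, 4, 3, 0])  # destination node of each edge
linear = nn.Linear(hidden, hidden)

# Per-edge version: the linear layer runs on m = len(src) rows
messages_edge = linear(h[src])
# Per-node version: run the linear layer once on n rows, then gather one row per edge
messages_node = linear(h)[src]

# Sum-aggregate messages into destination nodes, shape [num_nodes, hidden]
aggregated = torch.zeros(num_nodes, hidden).index_add(0, dst, messages_node)
print(torch.allclose(messages_edge, messages_node, atol=1e-6))  # True
```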

In [None]:
class OptimizedSimpleMPNNLayer(MPNNLayerBase):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        pass

test_mpnn_layer(OptimizedSimpleMPNNLayer, expected_simple_mpnn_output)