# Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from sklearn.metrics import f1_score

import torch
import torch.nn as nn

def format_pytorch_version(version):
    """Return a PyTorch version string without its local build suffix.

    Everything from the first ``+`` onward (for example ``+cu118``) is
    dropped; a string that contains no ``+`` is returned unchanged.
    """
    base_version, _, _ = version.partition("+")
    return base_version

# Installed PyTorch version, with any local build tag (the part after "+")
# stripped so it can be interpolated into the wheel URLs further down.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
    """Return the wheel-index tag matching a CUDA version string.

    Args:
        version: CUDA version such as ``"11.8"``, or ``None`` when PyTorch
            was built without CUDA support (``torch.version.cuda`` is
            ``None`` on CPU-only builds).

    Returns:
        ``"cu"`` followed by the version with its dots removed (``"11.8"``
        becomes ``"cu118"``), or ``"cpu"`` when ``version`` is ``None``.
    """
    # CPU-only builds report no CUDA version; calling ``.replace`` on None
    # would raise AttributeError, so map that case to the "cpu" tag.
    if version is None:
        return "cpu"
    return "cu" + version.replace(".", "")

# CUDA version reported by the installed PyTorch build, converted to the tag
# that is interpolated into the (commented-out) wheel URLs below.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-geometric

import torch_geometric.nn as graphnn
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("\nDevice: ", device)

# Dataset

We use the Protein-Protein Interaction (PPI) network dataset which includes:
- 20 graphs for training 
- 2 graphs for validation
- 2 graphs for testing

One graph of the PPI dataset has on average 2372 nodes. Each node:
- 50 features : positional gene sets / motif gene / immunological signatures ...
- 121 (binary) labels : gene ontology sets (way to classify gene products like proteins).

**This problem aims to predict, for a given PPI graph, the correct labels of each node**.

**It is a node (multi-label) classification task** (trained using supervised learning). 

In [None]:
# Number of graphs grouped into one mini-batch by each DataLoader.
BATCH_SIZE = 2

# Each split is loaded from (and cached under) its own root directory.
train_dataset = PPI(root="data/GNN/train", split="train")
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

val_dataset = PPI(root="data/GNN/val", split="val")
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

test_dataset = PPI(root="data/GNN/test", split="test")
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Feature and label widths are read from the first training graph.
n_features, n_classes = train_dataset[0].x.shape[1], train_dataset[0].y.shape[1]

print("Number of samples in the train dataset: ", len(train_dataset))
# Report the validation split itself (this line previously printed the test split size).
print("Number of samples in the val dataset: ", len(val_dataset))
print("Number of samples in the test dataset: ", len(test_dataset))

print("Number of features per node: ", n_features)
print("Number of classes per node: ", n_classes)

# Models

In [None]:
def evaluate(model, loss_fcn, device, dataloader):
    """Return the mean micro-averaged F1 score of ``model`` over ``dataloader``.

    Args:
        model: Module called as ``model(batch.x, batch.edge_index)`` and
            returning one logit per node and per label.
        loss_fcn: Unused; kept so existing callers keep working.
        device: Device each batch is moved to before the forward pass.
        dataloader: Iterable of batches exposing ``x``, ``edge_index``,
            ``y`` and ``to(device)``.

    Returns:
        The mean of the per-batch micro F1 scores, or NaN when
        ``dataloader`` yields no batch.
    """
    model.eval()
    batch_scores = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index)
            # A logit >= 0 corresponds to a sigmoid probability >= 0.5.
            predict = np.where(logits.detach().cpu().numpy() >= 0, 1, 0)
            score = f1_score(batch.y.cpu().numpy(), predict, average="micro")
            batch_scores.append(score)

    if not batch_scores:
        # Same value the mean of an empty array yields, without the NumPy warning.
        return float("nan")

    return np.array(batch_scores).mean()


def train(
    model, loss_fcn, device, optimizer, max_epochs, train_dataloader, val_dataloader
):
    """Train ``model`` and record its validation F1 score every fifth epoch.

    Args:
        model: Module called as ``model(batch.x, batch.edge_index)``.
        loss_fcn: Callable applied to ``(logits, batch.y)`` returning the loss.
        device: Device every batch is moved to.
        optimizer: Optimizer stepped once per training batch.
        max_epochs: Number of passes over ``train_dataloader``.
        train_dataloader: Iterable of training batches.
        val_dataloader: Iterable of validation batches handed to ``evaluate``.

    Returns:
        A pair ``(epochs, scores)``: the zero-based indices of the epochs at
        which validation ran, and the score obtained at each of them.
    """
    evaluated_epochs, val_scores = [], []

    for epoch in range(max_epochs):
        model.train()
        batch_losses = []

        for batch in train_dataloader:
            optimizer.zero_grad()
            batch = batch.to(device)
            # Forward pass: raw logits for every node of the batch.
            logits = model(batch.x, batch.edge_index)
            loss = loss_fcn(logits, batch.y)
            # Backward pass followed by the parameter update.
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        mean_loss = np.array(batch_losses).mean()
        print("Epoch {:05d} | Loss: {:.4f}".format(epoch + 1, mean_loss))

        # Validation only runs on epochs 0, 5, 10, ...
        if epoch % 5 != 0:
            continue

        val_score = evaluate(model, loss_fcn, device, val_dataloader)
        print("F1-Score: {:.4f}".format(val_score))
        val_scores.append(val_score)
        evaluated_epochs.append(epoch)

    return evaluated_epochs, val_scores

In [None]:
def plot_f1_score(epoch_list, scores):
    """Plot F1 scores against the epochs at which they were measured.

    Args:
        epoch_list: Epoch indices, used as x coordinates.
        scores: F1 scores, one per entry of ``epoch_list``.
    """
    plt.figure(figsize=[10, 5])
    plt.plot(epoch_list, scores)
    plt.title("Evolution of F1-Score w.r.t epochs")
    # Fixed y range so that plots of different models are directly comparable.
    plt.ylim([0.0, 1.0])
    plt.xlabel("Epochs")
    plt.ylabel("F1-Score")
    plt.show()


### Graph Convolution

https://arxiv.org/pdf/1609.02907.pdf

In [None]:
class ConvGraphModel(nn.Module):
    """Stack of three GCNConv layers producing per-node logits.

    The first two layers are each followed by an ELU activation; the third
    one returns its output unactivated.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers.

        Args:
            input_size: Number of input features per node.
            hidden_size: Width of the two hidden layers.
            output_size: Number of outputs per node.
        """
        super().__init__()

        self.graphconv1 = graphnn.GCNConv(input_size, hidden_size)
        self.graphconv2 = graphnn.GCNConv(hidden_size, hidden_size)
        self.graphconv3 = graphnn.GCNConv(hidden_size, output_size)

        self.elu = nn.ELU()

    def forward(self, x, edge_index):
        """Return the output of the last layer for node features ``x``."""
        hidden = x
        # Hidden layers: graph convolution followed by ELU.
        for conv in (self.graphconv1, self.graphconv2):
            hidden = self.elu(conv(hidden, edge_index))
        # Output layer: no activation, the raw values are returned.
        return self.graphconv3(hidden, edge_index)
    
    
# Build the GCN with 256 hidden units and move it to the selected device.
convolution_model = ConvGraphModel(
    input_size=n_features, hidden_size=256, output_size=n_classes
).to(device)

# Bare expression: in a notebook cell this displays the module structure.
convolution_model

In [None]:
max_epochs = 200

# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.
loss_fcn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(convolution_model.parameters(), lr=0.005)

# Train the GCN; keep the evaluated epochs and validation scores for plotting.
epoch_list, convolution_model_scores = train(
    model=convolution_model,
    loss_fcn=loss_fcn,
    device=device,
    optimizer=optimizer,
    max_epochs=max_epochs,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

In [None]:
# Score the trained convolution model on the held-out test graphs.
score_test = evaluate(convolution_model, loss_fcn, device, test_dataloader)
print("Convolution Model : F1-Score on the test set: {:.4f}".format(score_test))

# Validation F1 curve recorded while training the convolution model.
plot_f1_score(epoch_list, convolution_model_scores)

### Graph Attention

https://arxiv.org/pdf/1710.10903.pdf

In [None]:
class AttGraphModel(nn.Module):
    """Three GATConv layers with a skip connection around the second one.

    The first two layers use 4 attention heads each and are followed by an
    ELU activation; the last layer uses 6 heads with ``concat=False`` and
    returns its output unactivated.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers.

        Args:
            input_size: Number of input features per node.
            hidden_size: Output width of each head in the first two layers.
            output_size: Number of outputs per node.
        """
        super().__init__()

        # Default initialization is Glorot.
        self.graphat1 = graphnn.GATConv(input_size, hidden_size, heads=4)
        # The second layer consumes the 4 * hidden_size features of the first.
        self.graphat2 = graphnn.GATConv(4 * hidden_size, hidden_size, heads=4)

        # Last layer averages features (concat=False).
        self.graphat3 = graphnn.GATConv(
            4 * hidden_size, output_size, heads=6, concat=False
        )

        self.elu = nn.ELU()

    def forward(self, x, edge_index):
        """Forward pass with a skip connection around the second layer."""
        first = self.elu(self.graphat1(x, edge_index))
        # Skip connection: the first layer's activation is added back
        # before the non-linearity.
        second = self.elu(self.graphat2(first, edge_index) + first)
        return self.graphat3(second, edge_index)
    

# Build the attention model with 256 hidden units per head and move it to the selected device.
attention_model = AttGraphModel(
    input_size=n_features, hidden_size=256, output_size=n_classes
).to(device)

# Bare expression: in a notebook cell this displays the module structure.
attention_model

In [None]:
max_epochs = 200

# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.
loss_fcn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.005)

# Train the attention model; keep the evaluated epochs and validation scores for plotting.
epoch_list, attention_model_scores = train(
    model=attention_model,
    loss_fcn=loss_fcn,
    device=device,
    optimizer=optimizer,
    max_epochs=max_epochs,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

In [None]:
# Score the attention model itself: this cell previously re-evaluated and
# re-plotted the convolution model under the "Attention Model" label.
score_test = evaluate(attention_model, loss_fcn, device, test_dataloader)
print("Attention Model : F1-Score on the test set: {:.4f}".format(score_test))

# Validation F1 curve recorded while training the attention model.
plot_f1_score(epoch_list, attention_model_scores)

# Comparison

In [None]:
def plot_f1_scores(epoch_list, convolution_model_scores, attention_model_scores):
    """Plot the validation F1 curves of both models on the same axes.

    Args:
        epoch_list: Epoch indices shared by both curves (x coordinates).
        convolution_model_scores: F1 scores drawn in blue ("Basic Model").
        attention_model_scores: F1 scores drawn in red ("Student Model").
    """
    plt.figure(figsize=[10, 5])
    plt.plot(epoch_list, convolution_model_scores, "b", label="Basic Model")
    plt.plot(epoch_list, attention_model_scores, "r", label="Student Model")
    plt.title("Evolution of f1 score w.r.t epochs")
    plt.ylim([0.0, 1.0])
    # Epochs are plotted on x and scores on y; the two labels were swapped before.
    plt.xlabel("Epochs")
    plt.ylabel("F1-Score")
    plt.legend()
    plt.show()

In [None]:
# Score the attention model (labelled "Student Model" here) on the test graphs.
score_test = evaluate(attention_model, loss_fcn, device, test_dataloader)
print("Student Model : F1-Score on the test set: {:.4f}".format(score_test))

# Overlay the validation F1 curves of both models.
plot_f1_scores(epoch_list, convolution_model_scores, attention_model_scores)

# Conclusion

We tried to reproduce the Graph Attention Networks paper. The attention mechanism gives an advantage over GCN because it allows the model to assign different, learned weights to the different neighbors of a node rather than performing a fixed aggregation, and therefore to have more granularity on the interesting features. GATConv can choose to emphasize or de-emphasize features from individual neighbors, whereas GCNConv weights all neighbors with fixed coefficients that depend only on node degrees.   
Furthermore GATConv is able to capture long-range dependencies with the attention mechanism, contrary to GCNConv that only focuses on immediate neighbors.
Finally, it can be worth noting that GAT can be parallelized when multiple heads are used, leading to even more efficiency without big time addition.  

As guessed above, the results are better in terms of F1 score with the GAT layers.  
The state of the art (SOA) is at 0.973, which is not so far away from our result.


**The oversmoothing problem**  

Oversmoothing in GNNs is the phenomenon where increasing the network depth leads to homogeneous node representations: the node features tend to become more and more similar.  
In *Rusch, Bronstein, Mishra, A Survey on Oversmoothing in Graph Neural Networks, 2023* it is defined as "the exponential
convergence of all node features towards the same constant value as the number of layers in the GNN increases".  
The oversmoothing phenomenon is an issue because as nodes become more and more similar, we lose the ability to capture the differences between them and therefore lose performance.

As per the same article, there are solutions to mitigate over-smoothing like normalization and regularization, or residual connections.    

In our code there is neither regularization nor dropout. However, by essence the attention mechanism addresses the oversmoothing issue by capturing long-range dependencies with different importance weights per node and per head. Furthermore, skip connections can help mitigate over-smoothing.  
Based on these observations, we can say that the model is rather robust to oversmoothing.