# Assignment — Graph Contrastive Learning

### Task 1. Augmentation (2 points)

In [None]:
# Install DGL quietly with pip (IPython shell escape). The commented-out
# line is the alternative that installs the CUDA 11.1 wheel (`dgl-cu111`)
# from the DGL wheel repository.
#!pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html -q
!pip install dgl -q

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests

import dgl
from dgl.nn import GraphConv
from dgl.dataloading import GraphDataLoader
from dgl.data import DGLDataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Subset

from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression

from IPython.display import clear_output

<img src='https://raw.githubusercontent.com/netspractice/advanced_gnn/made2021/assignment_contrastive_learning/contrastive_learning.png' width=500>

Source: https://arxiv.org/abs/2103.00111

Contrastive learning aims to learn representations by maximizing feature consistency between differently augmented views of the same sample, where the augmentations encode data- or task-specific priors. In the case of graph representation learning, several augmentation techniques can be used to learn graph embeddings for downstream tasks such as classification.

Write a class `GraphAugmentation` with a function `transform` that takes a graph and returns an augmented graph. Types of augmentation:
* `drop_nodes` — randomly drops a share of nodes with a given ratio
* `pert_edges` — randomly perturbs (rewires) a share of edges with a given ratio
* `attr_mask` — randomly masks (zeroes out) a share of the node attribute dimensions, given a ratio and the attribute's name in the `ndata` collection; the same dimensions are masked for every node
* `rw_subgraph` — builds a subgraph based on random walk
* `identical` — the same graph, no augmentation

Augmentations are applied to graphs with self-loops, so keep self-loops during edge perturbation. Parallel edges are allowed after perturbation. A random walk subgraph is constructed by (1) picking a random starting node, (2) adding all of its neighbors, and (3) adding all neighbors of a random node already in the subgraph, repeating step 3 until the number of nodes reaches the share `(1 - ratio)` of the original number of nodes.

In [None]:
class GraphAugmentation():
    """Produce a randomly augmented version of a DGL graph.

    Parameters
    ----------
    type : str
        One of ``'drop_nodes'``, ``'pert_edges'``, ``'attr_mask'``,
        ``'rw_subgraph'`` or ``'identical'``.
    ratio : float
        Share of nodes / edges / attribute columns affected by the
        augmentation. For ``'rw_subgraph'`` the subgraph grows until it holds
        at least ``int(n * (1 - ratio))`` of the ``n`` nodes.
    node_feat : str
        Key in ``g.ndata`` whose columns ``'attr_mask'`` zeroes out.
    """

    def __init__(self, type, ratio=0.2, node_feat='attr'):
        # ``type`` shadows the builtin, but it is part of the public keyword
        # interface, so the name is kept.
        self.type = type
        self.ratio = ratio
        self.node_feat = node_feat

    def transform(self, g):
        """Return an augmented version of ``g`` (``g`` itself for 'identical').

        Raises
        ------
        ValueError
            If ``self.type`` is not a known augmentation (previously this
            silently returned ``None``).
        """
        if self.type == 'drop_nodes':
            return self.drop_nodes(g)
        elif self.type == 'pert_edges':
            return self.pert_edges(g)
        elif self.type == 'attr_mask':
            return self.attr_mask(g)
        elif self.type == 'rw_subgraph':
            return self.rw_subgraph(g)
        elif self.type == 'identical':
            return g
        raise ValueError('Unknown augmentation type: {!r}'.format(self.type))

    def drop_nodes(self, g):
        """Remove a random share ``ratio`` of nodes together with their edges.

        Exactly ``int(n * (1 - ratio))`` nodes survive and keep their
        features; ``g`` itself is not modified.
        """
        n = g.num_nodes()
        n_keep = int(n * (1 - self.ratio))
        # Cast to the graph's id dtype: DGL rejects id tensors of another dtype.
        dropped = torch.randperm(n, device=g.device)[n_keep:].to(g.idtype)
        return dgl.remove_nodes(g, dropped)

    def pert_edges(self, g):
        """Rewire a random share ``ratio`` of the non-self-loop edges.

        The chosen edges are removed and the same number of edges with
        uniformly random endpoints is added, so the edge count is unchanged.
        Self-loops are never removed; parallel edges may appear.
        """
        src, dst = g.edges()  # positions in these tensors are edge ids
        candidates = torch.nonzero(src != dst, as_tuple=True)[0]
        n_pert = int(len(candidates) * self.ratio)
        if n_pert == 0:
            return g
        pick = torch.randperm(len(candidates), device=candidates.device)[:n_pert]
        aug_g = dgl.remove_edges(g, candidates[pick].to(g.idtype))
        n = g.num_nodes()
        new_src = torch.randint(0, n, (n_pert,), dtype=g.idtype, device=g.device)
        new_dst = torch.randint(0, n, (n_pert,), dtype=g.idtype, device=g.device)
        return dgl.add_edges(aug_g, new_src, new_dst)

    def attr_mask(self, g):
        """Zero out ``int(d * ratio)`` random columns of ``g.ndata[node_feat]``.

        The same columns are zeroed for every node. The original graph's
        features are left untouched.
        """
        feat = g.ndata[self.node_feat]
        n_cols = feat.shape[1]
        cols = torch.randperm(n_cols, device=feat.device)[:int(n_cols * self.ratio)]
        masked = feat.clone()
        masked[:, cols] = 0
        # local_var() returns a new graph object with its own feature mapping,
        # so assigning the masked tensor below does not touch ``g``.
        aug_g = g.local_var()
        aug_g.ndata[self.node_feat] = masked
        return aug_g

    def rw_subgraph(self, g):
        """Return the node-induced subgraph grown by neighbourhood expansion.

        Start from a random node and add all its neighbours. Then repeatedly
        add all neighbours of a random not-yet-expanded node of the subgraph
        until it holds at least ``int(n * (1 - ratio))`` nodes. Growth also
        stops when every collected node has been expanded, which avoids an
        infinite loop on components smaller than the target. Edge direction
        is ignored when collecting neighbours.
        """
        n = g.num_nodes()
        if n == 0:
            return g
        target = int(n * (1 - self.ratio))

        def neighbours(v):
            return set(g.successors(v).tolist()) | set(g.predecessors(v).tolist())

        start = int(torch.randint(n, (1,)))
        nodes = {start} | neighbours(start)
        expanded = {start}
        while len(nodes) < target:
            frontier = sorted(nodes - expanded)
            if not frontier:
                break  # whole component collected; cannot grow further
            v = frontier[int(torch.randint(len(frontier), (1,)))]
            expanded.add(v)
            nodes |= neighbours(v)
        idx = torch.tensor(sorted(nodes), dtype=g.idtype, device=g.device)
        return dgl.node_subgraph(g, idx)

In [None]:
# Toy fixture for the augmentation checks below: a random graph with 100
# nodes and 300 edges, normalised to exactly one self-loop per node (all
# existing self-loops removed, then one added per node), and a constant
# all-ones 10-dimensional node attribute 'attr'.
g = dgl.rand_graph(100, 300)
g = g.remove_self_loop()
g = g.add_self_loop()
g.ndata['attr'] = torch.ones(100, 10)
# Last expression: the notebook displays the graph summary.
g

In [None]:
# drop_nodes check: feature width is preserved, the node count shrinks to
# exactly int(n * (1 - ratio)), and the result has more than one zero
# Laplacian eigenvalue, i.e. more than one connected component.
ratio = 0.7
aug = GraphAugmentation('drop_nodes', ratio=ratio)
aug_g = aug.transform(g)
assert aug_g.ndata['attr'].shape[1] == g.ndata['attr'].shape[1]
assert aug_g.ndata['attr'].shape[0] < g.ndata['attr'].shape[0]
assert aug_g.ndata['attr'].shape[0] == int(g.number_of_nodes() * (1 - ratio))
G = nx.Graph(aug_g.to_networkx())
assert np.isclose(nx.laplacian_spectrum(G), 0).sum() > 1

In [None]:
# pert_edges check: node features unchanged, edge count unchanged, but the
# adjacency matrix must differ from the original.
aug = GraphAugmentation('pert_edges', ratio=0.2)
aug_g = aug.transform(g)
assert aug_g.ndata['attr'].shape == g.ndata['attr'].shape
assert aug_g.number_of_edges() == g.number_of_edges()
assert not torch.all(aug_g.adj().to_dense() == g.adj().to_dense())

In [None]:
# attr_mask check: the columns zeroed in row 0 must be zeroed in every row
# (column-wise masking), and every other entry keeps its original value 1.
aug = GraphAugmentation('attr_mask', ratio=0.2, node_feat='attr')
aug_g = aug.transform(g)
assert aug_g.ndata['attr'].shape == (100, 10)
mask = (aug_g.ndata['attr'][0, :] == 0).repeat(100, 1)
assert torch.all(aug_g.ndata['attr'][mask] == 0)
assert torch.all(aug_g.ndata['attr'][~mask] == 1)

In [None]:
# rw_subgraph check: fewer nodes than the original, same feature width, and
# exactly one zero Laplacian eigenvalue, i.e. the subgraph is connected.
aug = GraphAugmentation('rw_subgraph', ratio=0.7)
aug_g = aug.transform(g)
assert aug_g.ndata['attr'].shape[1] == g.ndata['attr'].shape[1]
assert aug_g.ndata['attr'].shape[0] < g.ndata['attr'].shape[0]
G = nx.Graph(aug_g.to_networkx())
assert np.isclose(nx.laplacian_spectrum(G), 0).sum() == 1

### Task 2. Contrastive dataset (2 points)

During training, we will feed augmented graphs into the encoder to obtain graph embeddings. Let us prepare a contrastive graph dataset class in which each element holds the original graph, its two augmented views, and a label.

Write a class `ContrastiveDataset` with a function `__getitem__` that takes a graph's index and returns a tuple:
* an initial graph
* a graph after the first augmentation
* a graph after the second augmentation
* a label

In [None]:
class ContrastiveDataset(DGLDataset):
    """Graph dataset whose items pair each graph with two augmented views.

    Graphs and labels are read with ``dgl.load_graphs`` from ``filename``. The
    graph-data dict must contain a ``'labels'`` tensor.

    Parameters
    ----------
    filename : str
        Path of the DGL binary graph file; also used as the dataset name.
    augmentations : sequence of two objects with a ``transform(g)`` method
        (e.g. ``GraphAugmentation``). They produce the first and the second
        view respectively.
    """

    def __init__(self, filename, augmentations):
        self.filename = filename
        self.graphs = None
        self.labels = None
        self.augmentations = augmentations
        assert len(self.augmentations) == 2
        # NOTE(review): DGLDataset.__init__ presumably triggers process()
        # (len(dataset) is used right after construction), so the attributes
        # above must be set before this call.
        super().__init__(name=filename)

    def process(self):
        graphs, graph_data = dgl.load_graphs(self.filename)
        self.graphs = graphs
        self.labels = graph_data['labels']

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, view1, view2, label)`` for the graph at ``idx``.

        Both augmentations are applied anew on every call (nothing is
        cached), so repeated access yields fresh views of the same graph.
        """
        g = self.graphs[idx]
        first, second = self.augmentations
        return g, first.transform(g), second.transform(g), self.labels[idx]

PROTEINS is a dataset with 1113 proteins where nodes are secondary structure elements and there is an edge between two nodes if they are neighbors in the amino-acid sequence or in 3D space. Each node has one of 3 discrete labels representing helix, sheet or turn. Proteins are divided into two classes: enzymes and non-enzymes. Source: https://arxiv.org/abs/2007.08663.

Let us create a dataset with dropping nodes and masking attributes augmentations.

In [None]:
url = 'https://github.com/netspractice/advanced_gnn/raw/made2021/assignment_contrastive_learning/proteins.bin'
# Fail loudly on an HTTP error instead of silently saving an error page as
# 'proteins.bin', and close the file handle deterministically.
response = requests.get(url, timeout=60)
response.raise_for_status()
with open('proteins.bin', 'wb') as f:
    f.write(response.content)

# First view: drop_nodes with ratio 0.1; second view: attr_mask with ratio 0.4
# on the 'attr' node features.
augmentations = []
augmentations.append(GraphAugmentation('drop_nodes', ratio=0.1))
augmentations.append(GraphAugmentation('attr_mask', ratio=0.4, node_feat='attr'))
dataset = ContrastiveDataset(filename='proteins.bin', augmentations=augmentations)
N = len(dataset)
N

In [None]:
# Dataset item check: graph 0 has 42 nodes with 3-dim attributes, and the
# drop_nodes(0.1) view keeps int(42 * 0.9) = 37 of them.
# NOTE(review): the last assert presumably guards against attr_mask writing
# into the original graph's features (assumes 'attr' holds at most 1 per node,
# e.g. one-hot labels) — TODO confirm.
g, aug_g1, aug_g2, label = dataset[0]
assert g.ndata['attr'].shape == (42, 3)
assert aug_g1.ndata['attr'].shape == (37, 3)
assert g.ndata['attr'].sum() <= 42

Since we will use the random walk subgraph augmentation, which only explores the connected component of its starting node, we want to make sure all initial graphs are connected.

Write a function `connected_subset` that takes an initial dataset and returns a `torch.utils.data.dataset.Subset` with connected graphs only.

In [None]:
def connected_subset(dataset):
    """Return a ``Subset`` of ``dataset`` holding only connected graphs.

    A graph counts as connected when its undirected simple version
    (``nx.Graph`` of ``g.to_networkx()``, so edge directions and multi-edges
    are ignored) has a single connected component. Graphs without nodes are
    excluded, since connectivity is undefined for them.

    Parameters
    ----------
    dataset : dataset whose items start with the original graph, e.g.
        ``ContrastiveDataset``. If it exposes a ``graphs`` list, that list is
        read directly so no augmentations are sampled while filtering.

    Returns
    -------
    torch.utils.data.Subset
        ``dataset`` restricted to the connected graphs, in original order.
    """
    graphs = getattr(dataset, 'graphs', None)
    if graphs is None:
        graphs = (dataset[i][0] for i in range(len(dataset)))

    keep = [
        idx for idx, g in enumerate(graphs)
        if g.num_nodes() > 0 and nx.is_connected(nx.Graph(g.to_networkx()))
    ]
    return Subset(dataset, keep)

In [None]:
# Keep only connected graphs; from here on N is the size of the connected
# subset, which later cells use for sampling and splitting.
c_dataset = connected_subset(dataset)
N = len(c_dataset)
assert N == 1067

Let us look at some graphs in the dataset.

In [None]:
# Draw 16 randomly chosen connected graphs on a 4x4 grid, colouring nodes by
# the graph label (orange = 0 / 'non-enzymes', green = 1 / 'enzymes').
# Note: this rebinds the notebook globals `g` and `l`.
colors = ['tab:orange', 'tab:green']
plt.figure(figsize=(12, 12))
np.random.seed(0)
for i in range(16):
    plt.subplot(4, 4, i+1)
    g, _, _, l = c_dataset[np.random.randint(N)]
    g = nx.Graph(g.to_networkx())
    # Self-loops are removed for drawing only.
    g.remove_edges_from(nx.selfloop_edges(g))
    nx.draw_kamada_kawai(g, node_size=30, node_color=colors[l])
    plt.title('enzymes' if l == 1 else 'non-enzymes')

### Task 3. GCN Encoder (1 point)

Let the encoder be a two-layer GCN (`GraphConv` in `dgl`) followed by mean graph pooling and a two-layer MLP projection head. All layers except the input and output ones have `hidden_dim` dimensionality. Use ReLU as the activation function.

Write a class `GCNEncoder` with a function `forward` that takes a batch of graphs and the node attribute name in the `ndata` collection, and returns graph embeddings.

In [None]:
class GCNEncoder(nn.Module):
    """Graph-level encoder: 2-layer GCN -> mean pooling -> 2-layer MLP head.

    Layer sizes are ``input_dim -> hidden_dim -> hidden_dim`` for the GCN and
    ``hidden_dim -> hidden_dim -> output_dim`` for the projection head. ReLU
    follows each GCN layer and sits between the two head layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GraphConv(input_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, g, node_feat):
        """Embed every graph of the (possibly batched) graph ``g``.

        Parameters
        ----------
        g : dgl.DGLGraph
            A single graph or a ``dgl.batch`` of graphs. By default DGL's
            ``GraphConv`` raises an error on nodes with zero in-degree, so
            graphs are expected to carry self-loops.
        node_feat : str
            Key of the input node features in ``g.ndata`` (cast to float).

        Returns
        -------
        torch.Tensor
            Shape ``(g.batch_size, output_dim)``, one row per graph.
        """
        h = g.ndata[node_feat].float()
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        # local_scope keeps the temporary 'h' field from leaking into g.ndata.
        with g.local_scope():
            g.ndata['h'] = h
            pooled = dgl.mean_nodes(g, 'h')
        return self.head(pooled)

In [None]:
# Encoder shape check: two random 100-node graphs (one self-loop per node,
# all-ones 10-dim 'attr') are batched; the encoder must return one
# output_dim-sized embedding per graph.
batch = []
for _ in range(2):
    g = dgl.rand_graph(100, 300)
    g = g.remove_self_loop()
    g = g.add_self_loop()
    g.ndata['attr'] = torch.ones(100, 10)
    batch.append(g)
batch = dgl.batch(batch)

encoder = GCNEncoder(input_dim=10, hidden_dim=32, output_dim=16)
emb = encoder(batch, 'attr')
assert emb.shape == (2, 16)

### Task 4. Classification on untrained encoder (1 point)

In [None]:
# Use the GPU when available, otherwise the CPU; displayed as the cell output.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

Let us check the logistic regression model on the untrained encoder output.

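Write a function `train_test_split` that splits the dataset into train and test sets using the given index arrays `train_idx` and `test_idx`. It returns the batched train graphs, the batched test graphs, and the corresponding train and test labels.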

In [None]:
def train_test_split(c_dataset, train_idx, test_idx, device=None):
    """Batch the original graphs and labels at two sets of positions.

    Parameters
    ----------
    c_dataset : dataset whose items are ``(graph, view1, view2, label)``,
        e.g. the connected ``Subset`` of a ``ContrastiveDataset``.
    train_idx, test_idx : sequences of int
        Positions in ``c_dataset`` (e.g. produced by ``np.where``).
    device : torch.device, optional
        Where to place the outputs. Defaults to CUDA when available and CPU
        otherwise, so the batches can be fed straight to an encoder moved to
        that device.

    Returns
    -------
    graph_train, graph_test : dgl.DGLGraph
        ``dgl.batch`` of the original (un-augmented) graphs.
    y_train, y_test : torch.Tensor
        1-D integer label tensors aligned with the batched graphs.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def _gather(indices):
        # Only the original graph (item[0]) and the label (item[3]) are
        # needed; the freshly sampled views are discarded.
        items = [c_dataset[int(i)] for i in indices]
        graphs = dgl.batch([item[0] for item in items]).to(device)
        labels = torch.tensor([int(item[3]) for item in items], device=device)
        return graphs, labels

    graph_train, y_train = _gather(train_idx)
    graph_test, y_test = _gather(test_idx)
    return graph_train, graph_test, y_train, y_test

In [None]:
# Seeded random 90/10 split by position in c_dataset.
# NOTE: int(0.9 * N) + int(0.1 * N) can be less than N. With N = 1067 that is
# 960 + 106 = 1066, so the last graph (position 1066) is in neither split.
np.random.seed(0)
ratio = [0.9, 0.1] # train test ratio
split_idx = ['train'] * int(ratio[0] * N) \
    + ['test'] * int(ratio[1] * N)
split_idx = np.random.permutation(split_idx)
train_idx = np.where(split_idx == 'train')[0]
test_idx = np.where(split_idx == 'test')[0]

graph_train, graph_test, y_train, y_test = train_test_split(
    c_dataset, train_idx, test_idx)
# Total node counts of the batched graphs and label counts per split.
assert graph_train.ndata['attr'].shape == (36363, 3)
assert graph_test.ndata['attr'].shape == (3740, 3)
assert y_train.shape == (960, )
assert y_test.shape == (106, )

Let us check the classification score and look at tSNE visualization.

In [None]:
# Randomly initialised (untrained) encoder for the baseline score; the module
# repr is displayed as the cell output.
encoder = GCNEncoder(input_dim=3, hidden_dim=32, output_dim=16)
encoder.to(device)

In [None]:
def classification_score(graph_train, graph_test, y_train, y_test, encoder, show=True):
    """Evaluate frozen encoder embeddings with a logistic-regression probe.

    Both batched graphs are embedded without gradient tracking using the
    'attr' node features. A default ``LogisticRegression`` is fitted on the
    train embeddings and its accuracy on the test embeddings is returned.
    When ``show`` is true, a 2-D t-SNE of the train embeddings (coloured by
    label) is plotted and the accuracy is printed.
    """
    with torch.no_grad():
        train_emb = encoder(graph_train, 'attr').cpu()
        test_emb = encoder(graph_test, 'attr').cpu()

    train_labels = y_train.cpu()
    probe = LogisticRegression()
    probe.fit(train_emb, train_labels)
    accuracy = probe.score(test_emb, y_test.cpu())

    if not show:
        return accuracy

    plt.figure(figsize=(10, 6))
    coords = TSNE(n_components=2).fit_transform(train_emb)
    plt.scatter(coords[:, 0], coords[:, 1], c=train_labels, cmap=plt.cm.Set1_r, s=5)
    plt.title('tSNE visualization')
    plt.show()
    print('Accuracy: {:.4f}'.format(accuracy))
    return accuracy

In [None]:
# Baseline: probe accuracy of the untrained encoder (plots t-SNE, prints accuracy).
score = classification_score(graph_train, graph_test, y_train, y_test, encoder)

### Task 5. Contrastive loss (2 points)

In graph contrastive learning, pre-training is performed by maximizing the agreement between two augmented views of the same graph via a contrastive loss in the latent space. The contrastive loss is designed to maximize the consistency between positive pairs $z_{n,i}$, $z_{n,j}$ (embeddings of the same graph $n$ under the two augmentations $i$ and $j$) relative to negative pairs (embeddings of different graphs). Here we use the NT-Xent loss, defined for the $n$-th graph in a batch of $N$ graphs as follows:

$$l_{n}=-\log \frac{\exp \left(\text{sim}(z_{n, i}, z_{n, j}) / \tau \right)}{\sum_{n'=1, n' \neq n}^{N} \exp \left( \text{sim}(z_{n, i}, z_{n', j}) / \tau \right)}$$

where $\text{sim}$ is cosine similarity $\text{sim}(z_i, z_j) = z_i^\top z_j / (\| z_i \| \cdot \| z_j \|)$ and $\tau$ is a temperature parameter. Note that the denominator excludes the positive pair ($n' \neq n$), so the loss can be negative.

Source: https://arxiv.org/pdf/2010.13902.pdf

Write a function `ntxent` that takes a batch of embeddings of the first augmented views `x1` and a batch of embeddings of the second augmented views `x2` (row $n$ of both belongs to graph $n$), and returns the mean loss over all graphs $L = \frac{1}{N}\sum_{n=1}^N l_n$.

_Hint: it is possible to use matrix operations only, with no loops._

In [None]:
def ntxent(x1, x2, tau=0.1):
    """NT-Xent contrastive loss between two batches of view embeddings.

    Row ``n`` of ``x1`` and row ``n`` of ``x2`` are the two views of graph
    ``n`` (the positive pair). For the anchor ``x1[n]`` the negatives are the
    rows ``x2[n']`` with ``n' != n``. As in the formula above, the positive
    pair is *not* part of the denominator, so the loss can be negative.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        Embeddings of shape ``(N, d)`` with the same ``N``.
    tau : float
        Temperature.

    Returns
    -------
    torch.Tensor
        Scalar mean loss over the ``N`` graphs.

    Raises
    ------
    ValueError
        If the batches differ in size or hold fewer than two graphs (the
        denominator would then be an empty sum).
    """
    if x1.shape[0] != x2.shape[0]:
        raise ValueError('x1 and x2 must hold the same number of graphs')
    if x1.shape[0] < 2:
        raise ValueError('NT-Xent needs at least two graphs per batch')

    z1 = F.normalize(x1, dim=1)
    z2 = F.normalize(x2, dim=1)
    logits = z1 @ z2.T / tau  # logits[n, m] = sim(z_{n,i}, z_{m,j}) / tau
    positives = logits.diagonal()
    self_mask = torch.eye(logits.shape[0], dtype=torch.bool, device=logits.device)
    # log of the denominator, computed stably; -inf drops the positive term.
    negatives = logits.masked_fill(self_mask, float('-inf')).logsumexp(dim=1)
    return (negatives - positives).mean()

In [None]:
# Identical views with tau = 0.1: the positive logit is 1 / tau = 10 and the
# single negative has logit 0, so the loss is -log(e^10 / e^0) = -10.
x1 = torch.tensor([[1., 0.], [0., 1.]])
x2 = torch.tensor([[1., 0.], [0., 1.]])
assert ntxent(x1, x2) == -10

# Swapped views: positive logit 0, negative logit 10, so the loss is 10.
x1 = torch.tensor([[1., 0.], [0., 1.]])
x2 = torch.tensor([[0., 1.], [1., 0.]])
assert ntxent(x1, x2) == 10

# Seeded random batch of 128 graphs with 16-dim embeddings.
torch.manual_seed(0)
x1 = torch.randn(128, 16)
x2 = torch.randn(128, 16)
assert round(ntxent(x1, x2).item(), 4) == 7.191

### Task 6. Training loop (1 point)

Let us train encoder under contrastive loss and then check classification score.

In [None]:
# Shuffled mini-batches of 64 connected graphs; each batch yields batched
# (graphs, view1, view2, labels). The last, smaller batch is kept.
loader = GraphDataLoader(
    c_dataset,
    batch_size=64,
    drop_last=False,
    shuffle=True)

In [None]:
# Fresh encoder for contrastive pre-training, optimised with Adam (lr 0.005).
encoder = GCNEncoder(input_dim=3, hidden_dim=32, output_dim=16)
encoder.to(device)
opt = Adam(encoder.parameters(), lr=0.005)

Write a function `train` that takes the augmented batches, makes an optimization step and returns the loss value as a Python float.

In [None]:
def train(encoder, aug_batch1, aug_batch2, opt, node_feat='attr'):
    """Run one contrastive optimisation step and return the loss.

    Both batched views are moved to the encoder's device and embedded using
    ``node_feat`` features. The NT-Xent loss between the two embedding
    batches is back-propagated and ``opt`` takes one step.

    Parameters
    ----------
    encoder : nn.Module
        Graph encoder called as ``encoder(batch, node_feat)``.
    aug_batch1, aug_batch2 : dgl.DGLGraph
        Batched first and second views of the same graphs, in the same order.
    opt : torch.optim.Optimizer
        Optimiser over ``encoder``'s parameters.
    node_feat : str, optional
        Node feature key in ``ndata`` (default ``'attr'``).

    Returns
    -------
    float
        The loss value of this step.
    """
    device = next(encoder.parameters()).device
    encoder.train()
    opt.zero_grad()
    z1 = encoder(aug_batch1.to(device), node_feat)
    z2 = encoder(aug_batch2.to(device), node_feat)
    loss = ntxent(z1, z2)
    loss.backward()
    opt.step()
    return loss.item()

In [None]:
# Smoke test: take the first batch from the loader and run a single step.
for batch, aug_batch1, aug_batch2, label in loader:
    break
loss_item = train(encoder, aug_batch1, aug_batch2, opt)
# Strict type check: `train` must return a plain Python float (e.g. from
# `.item()`), not a tensor or numpy scalar.
assert type(loss_item) == float
assert loss_item > 0

Here is a training loop that accumulates mean loss per epoch.

In [None]:
# Contrastive pre-training: for each epoch, average the per-batch losses and
# redraw the loss curve. clear_output(wait=True) replaces the previous
# epoch's figure, so only the latest plot stays visible.
loss_vals = []
n_epochs = 30
for i in range(n_epochs):
    loss_epoch = []
    for batch, aug_batch1, aug_batch2, label in loader:
        loss_item = train(encoder, aug_batch1, aug_batch2, opt)
        loss_epoch.append(loss_item)
    loss_vals.append(sum(loss_epoch)/len(loss_epoch))
    plt.plot(loss_vals)
    plt.title('Contrastive loss. Epoch: {}/{}'.format(i+1, n_epochs))
    plt.show();
    clear_output(wait=True)

In [None]:
# Probe accuracy after pre-training; expected to beat 0.65.
score = classification_score(graph_train, graph_test, y_train, y_test, encoder)
assert score > 0.65

As we can see, self-supervised pre-training noticeably improves the classification score.

### Task 7. Augmentation comparison (1 point)

Here we aim to compare augmentation techniques and conclude which pair is better for PROTEINS dataset.

Write a function `run` that takes a filename with proteins, the number of epochs and a list of augmentation types. It returns an `np.array` with a classification score matrix whose rows correspond to the first augmentation and whose columns correspond to the second augmentation. Since the matrix is approximately symmetric, calculate only the upper-triangle values (including the diagonal) and leave the rest at zero.

This can take a while. To speed up grading, you may return a precomputed score matrix instead of running the actual training:
```
def run(filename, n_epochs, augs):
    scores = np.array([[0.5, 0.5, 0.5, 0.5, 0.5], [0, 0.5, 0.5, 0.5, 0.5], ...])
    return scores

    ### ACTUAL TRAINING
```

In [None]:
def run(filename, n_epochs, augs, batch_size=64, lr=0.005, ratios=None, seed=0):
    """Score every (first, second) augmentation pair by downstream accuracy.

    For each pair ``i <= j`` a fresh ``GCNEncoder`` is trained with the
    contrastive loss on views produced by ``augs[i]`` and ``augs[j]``. It is
    then evaluated with ``classification_score`` on one fixed 90/10 split of
    the connected graphs, shared by all pairs so scores are comparable.

    Parameters
    ----------
    filename : str
        DGL binary graph file passed to ``ContrastiveDataset``.
    n_epochs : int
        Training epochs per pair.
    augs : list of str
        Augmentation type names understood by ``GraphAugmentation``.
    batch_size, lr : optional
        Loader batch size and Adam learning rate.
    ratios : dict, optional
        Augmentation ratio per type; unlisted types use 0.2. Defaults to
        ``{'drop_nodes': 0.1, 'attr_mask': 0.4}``. If ``attr_mask`` masks
        ``int(d * ratio)`` columns, a ratio of 0.2 on the 3-dim PROTEINS
        attributes would mask none.
    seed : int
        Seeds the split permutation and each pair's torch RNG.

    Returns
    -------
    np.ndarray
        ``(len(augs), len(augs))`` accuracy matrix. Entry ``[i, j]`` is filled
        for ``i <= j``; the lower triangle stays 0.
    """
    if ratios is None:
        ratios = {'drop_nodes': 0.1, 'attr_mask': 0.4}
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load and filter the graphs once; only the augmentation pair changes below.
    base = ContrastiveDataset(
        filename=filename,
        augmentations=[GraphAugmentation('identical'), GraphAugmentation('identical')])
    subset = connected_subset(base)

    perm = np.random.RandomState(seed).permutation(len(subset))
    n_train = int(0.9 * len(subset))
    graph_train, graph_test, y_train, y_test = train_test_split(
        subset, perm[:n_train], perm[n_train:])
    graph_train, graph_test = graph_train.to(device), graph_test.to(device)
    y_train, y_test = y_train.to(device), y_test.to(device)
    input_dim = graph_train.ndata['attr'].shape[1]

    n_augs = len(augs)
    scores = np.zeros((n_augs, n_augs))
    for i in range(n_augs):
        for j in range(i, n_augs):
            # The subset reads items through `base`, so swapping the
            # augmentations here changes the views the loader produces.
            base.augmentations = [
                GraphAugmentation(augs[i], ratio=ratios.get(augs[i], 0.2)),
                GraphAugmentation(augs[j], ratio=ratios.get(augs[j], 0.2)),
            ]
            torch.manual_seed(seed)
            loader = GraphDataLoader(
                subset, batch_size=batch_size, drop_last=False, shuffle=True)
            encoder = GCNEncoder(input_dim=input_dim, hidden_dim=32, output_dim=16)
            encoder.to(device)
            opt = Adam(encoder.parameters(), lr=lr)
            for _ in range(n_epochs):
                for _, aug_batch1, aug_batch2, _ in loader:
                    train(encoder, aug_batch1, aug_batch2, opt)
            scores[i, j] = classification_score(
                graph_train, graph_test, y_train, y_test, encoder, show=False)
    return scores

In [None]:
# Pairwise comparison over five augmentation types.
augs = ['drop_nodes', 'pert_edges', 'attr_mask', 'rw_subgraph', 'identical']
res = run('proteins.bin', n_epochs=30, augs=augs)
# symm: column sums of the symmetrised matrix (upper triangle mirrored,
# diagonal counted once), i.e. each augmentation's total score over all pairs.
symm = (res.T + res - np.diag(res[range(5), range(5)])).sum(0)
# Column j must have exactly j + 1 positive entries: the upper triangle,
# diagonal included, is filled and the lower triangle is zero.
assert np.all((res > 0).sum(0) == np.arange(5) + 1)
assert np.all(res[res > 0] > 0.6)
# drop_nodes and attr_mask must beat the no-augmentation baseline overall.
assert symm[0] > symm[4]
assert symm[2] > symm[4]

In [None]:
# Display the score matrix (rows: first augmentation, columns: second).
pd.DataFrame(res, index=augs, columns=augs).round(4)