<a href="https://colab.research.google.com/github/memazouni/A-Comprehensive-ML-Workflow-for-HousePrices/blob/master/GraphSSL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Self-Supervised Learning for Graphs**
This Colab is a tutorial on self-supervised learning for graphs. Self-supervised learning is a class of unsupervised machine learning methods that learn rich representations of unstructured data without access to any labels. The accompanying repository implements a variety of commonly used components (augmentations, encoders, loss functions) for self-supervised learning on graphs, and can load commonly used graph datasets for a variety of downstream tasks. It is built on [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), a PyTorch library for graph machine learning.

## Installation

The cells below install torch-geometric and its dependencies, then clone the repository, which contains the utility code required to train the models. The full GraphSSL code is available on [GitHub](https://github.com/paridhimaheshwari2708/GraphSSL.git).

In [None]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-geometric
!git clone https://github.com/paridhimaheshwari2708/GraphSSL.git
%cd /content/GraphSSL/

Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.0.9
Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl (3.5 MB)
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.12
Collecting torch-geometric
  Downloading torch_geometric-2.0.2.tar.gz (325 kB)
Collecting rdflib
  Downloading rdflib-6.0.2-py3-none-any.whl (407 kB)
Collecting

## Setting up arguments to train the self-supervised model

In [None]:
import os
import torch
import numpy as np
import torch.nn as nn

'''
Change these arguments to change either the dataset / model / loss function / types of augmentations.
The augmentations mentioned in augment_list shall be applied sequentially to generate a positive pair for contrastive training.
Make sure to not add too many augmentations as that would change the fundamental structure of the input graph.
'''

args = {
    "device" : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "save" : "ssl_model",
    "lr" : 0.001,
    "epochs" : 20,
    "batch_size" : 64,
    "num_workers" : 2,
    "dataset" : "proteins", # Choices are ["proteins", "enzymes", "collab", 
                            # "reddit_binary", "reddit_multi", "imdb_binary", 
                            # "imdb_multi", "dd", "mutag", "nci1"]
    "model" : "gcn", # choices are ["gcn", "gin", "resgcn", "gat", "graphsage", "sgc"]
    "feat_dim" : 128,
    "layers" : 3,
    "loss" : "infonce", # choices are ["infonce", "jensen_shannon"]
    "augment_list" : ["edge_perturbation", "node_dropping"],
    # choices are ["edge_perturbation", "diffusion", "diffusion_with_sample", 
    # "node_dropping", "random_walk_subgraph", "node_attr_mask"]
    "train_data_percent" : 1.0,
}

class AttributeDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttributeDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

args = AttributeDict(args)

## Loading the dataset and creating dataloaders

The following cell loads the data and splits it into train, validation and test sets. It then builds a custom dataloader that returns paired data: the original graph and a positive view obtained by applying the augmentations listed in augment_list.
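To make the idea of a paired sample concrete, here is a minimal, dependency-free sketch of how two of the listed augmentations could be chained to produce a positive view. It is not the repository's implementation: the function names, the (num_nodes, edge list) graph representation, and the 20% perturbation ratio are all assumptions chosen for illustration.

```python
import random

def edge_perturbation(num_nodes, edges, ratio, rng):
    """Swap a fraction `ratio` of the edges for randomly drawn new ones."""
    edges = list(edges)
    n_change = int(len(edges) * ratio)
    # keep a random subset of the original edges
    existing = set(rng.sample(edges, len(edges) - n_change))
    # add random edges (no self-loops, no duplicates) until the count is restored
    while len(existing) < len(edges):
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:
            existing.add((u, v))
    return num_nodes, sorted(existing)

def node_dropping(num_nodes, edges, ratio, rng):
    """Drop a fraction `ratio` of the nodes together with their incident edges."""
    n_drop = int(num_nodes * ratio)
    dropped = set(rng.sample(range(num_nodes), n_drop))
    kept = [n for n in range(num_nodes) if n not in dropped]
    # relabel the surviving nodes to 0..len(kept)-1
    relabel = {old: new for new, old in enumerate(kept)}
    new_edges = [(relabel[u], relabel[v]) for u, v in edges
                 if u in relabel and v in relabel]
    return len(kept), new_edges

# hypothetical registry mirroring the names used in augment_list
AUGMENTATIONS = {
    "edge_perturbation": edge_perturbation,
    "node_dropping": node_dropping,
}

def make_pair(num_nodes, edges, augment_list, ratio=0.2, seed=0):
    """Return (anchor, positive); augmentations are applied one after another."""
    rng = random.Random(seed)
    anchor = (num_nodes, list(edges))
    positive = anchor
    for name in augment_list:
        positive = AUGMENTATIONS[name](*positive, ratio=ratio, rng=rng)
    return anchor, positive

# toy directed ring graph with 10 nodes
ring = [(i, (i + 1) % 10) for i in range(10)]
anchor, positive = make_pair(10, ring, ["edge_perturbation", "node_dropping"])
print("anchor:  ", anchor[0], "nodes,", len(anchor[1]), "edges")
print("positive:", positive[0], "nodes,", len(positive[1]), "edges")
```

The anchor is returned untouched, while the positive view has already lost nodes and had edges rewired after only two steps, which is why stacking many augmentations can change the fundamental structure of the input graph.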

In [None]:
from data import *

dataset, input_dim, num_classes = load_dataset(args.dataset)

# split the data into train / val / test sets
train_dataset, val_dataset, test_dataset = split_dataset(dataset, args.train_data_percent)

# build_loader returns a dataloader that yields paired samples: the original x and the positively
# augmented x obtained by applying the transformations in augment_list
train_loader = build_loader(args, train_dataset, "train")
val_loader = build_loader(args, val_dataset, "val")
test_loader = build_loader(args, test_dataset, "test")

Downloading https://www.chrsmrrs.com/graphkerneldatasets/PROTEINS.zip
Extracting /tmp/TUDataset/PROTEINS/PROTEINS/PROTEINS.zip
Processing...
Done!


# samples in train subset: 779
# samples in val subset: 222
# samples in test subset: 112


## Initializing the model and optimizer
Here, the model consists only of the GNN encoder.

In [None]:
from model import *

# easy initialization of the GNN model encoder to map graphs to embeddings needed for contrastive training 
model = Encoder(input_dim, args.feat_dim, n_layers=args.layers, gnn=args.model)
model = model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

## Model training

The following cells train a self-supervised encoder with a contrastive objective to obtain embeddings from raw graph data. This stage does not require the training samples to be labelled.

In [None]:
from loss import *

def run(epoch, mode, dataloader):
    if mode == "train":
        model.train()
    elif mode == "val" or mode == "test":
        model.eval()

    # build the loss callable named in args.loss, i.e. infonce() or jensen_shannon()
    contrastive_fn = eval(args.loss + "()")

    losses = []
    for data in dataloader:
        data.to(args.device)

        # only track gradients during training; val / test passes skip autograd bookkeeping
        with torch.set_grad_enabled(mode == "train"):
            # readout_anchor is the embedding of the original datapoint x on passing through the model
            readout_anchor = model((data.x_anchor,
                                    data.edge_index_anchor, data.x_anchor_batch))

            # readout_positive is the embedding of the positively augmented x on passing through the model
            readout_positive = model((data.x_pos,
                                      data.edge_index_pos, data.x_pos_batch))

            # negative samples for the contrastive loss are computed inside contrastive_fn
            loss = contrastive_fn(readout_anchor, readout_positive)

        if mode == "train":
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # keep track of loss values
        losses.append(loss.item())

    # gather the results for the epoch (mean of the per-batch losses)
    epoch_loss = sum(losses) / len(losses)
    return epoch_loss
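The loss itself lives in the repository's loss module and is not shown in this notebook. As a rough illustration of what an InfoNCE-style objective with in-batch negatives computes, here is a small pure-Python sketch. The cosine similarity, the temperature of 0.5, and the choice of every other graph's positive view as a negative are assumptions, so the repository's infonce may differ in each of these.

```python
import math

def cosine(u, v):
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def infonce_sketch(anchors, positives, temperature=0.5):
    """Mean InfoNCE loss: anchor i should match positive i, not positive j != i."""
    batch_size = len(anchors)
    total = 0.0
    for i in range(batch_size):
        # similarity of anchor i to every positive view in the batch
        logits = [cosine(anchors[i], positives[j]) / temperature
                  for j in range(batch_size)]
        # numerically stable log-sum-exp over the batch
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
        # negative log-probability assigned to the true pair (i, i)
        total += log_denominator - logits[i]
    return total / batch_size

# embeddings that cannot tell graphs apart: loss equals log(batch_size)
collapsed = [[1.0, 0.0]] * 4
uninformative = infonce_sketch(collapsed, collapsed)

# embeddings where each anchor matches only its own positive: loss is lower
distinct = [[1.0, 0.0], [0.0, 1.0]]
aligned = infonce_sketch(distinct, distinct)
print(uninformative, aligned)
```

Under these assumptions, a collapsed encoder that maps every graph to the same vector scores exactly log(batch_size), which gives a useful reference point when reading a loss curve.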

In [None]:
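# create the checkpoint directory if it does not exist yet
os.makedirs(os.path.join("logs", args.save), exist_ok=True)

# best_epoch is initialised so the final print cannot fail if val_loss never improves
best_epoch, best_train_loss, best_val_loss = None, float("inf"), float("inf")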

for epoch in range(args.epochs):
    train_loss = run(epoch, "train", train_loader)
    val_loss = run(epoch, "val", val_loader)
    log = "Epoch {}, Train Loss: {:.3f}, Val Loss: {:.3f}"
    print(log.format(epoch, train_loss, val_loss))

    # save model
    is_best_loss = False
    if val_loss < best_val_loss:
        best_epoch, best_train_loss, best_val_loss, is_best_loss = \
                                            epoch, train_loss, val_loss, True

    model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, 
                          best_train_loss, best_val_loss, is_best_loss)

print("Train Loss at epoch {} (best model): {:.3f}".format(best_epoch, best_train_loss))
print("Val Loss at epoch {} (best model): {:.3f}".format(best_epoch, best_val_loss))

Epoch 0, Train Loss: 2.382, Val Loss: 3.534
Epoch 1, Train Loss: 2.335, Val Loss: 2.816
Epoch 2, Train Loss: 2.309, Val Loss: 2.611
Epoch 3, Train Loss: 2.306, Val Loss: 2.679
Epoch 4, Train Loss: 2.291, Val Loss: 2.421
Epoch 5, Train Loss: 2.274, Val Loss: 2.604
Epoch 6, Train Loss: 2.276, Val Loss: 2.409
Epoch 7, Train Loss: 2.260, Val Loss: 2.449
Epoch 8, Train Loss: 2.258, Val Loss: 2.430
Epoch 9, Train Loss: 2.262, Val Loss: 2.427
Epoch 10, Train Loss: 2.248, Val Loss: 2.400
Epoch 11, Train Loss: 2.252, Val Loss: 2.324
Epoch 12, Train Loss: 2.244, Val Loss: 2.442
Epoch 13, Train Loss: 2.243, Val Loss: 2.368
Epoch 14, Train Loss: 2.247, Val Loss: 2.393
Epoch 15, Train Loss: 2.252, Val Loss: 2.899
Epoch 16, Train Loss: 2.238, Val Loss: 2.420
Epoch 17, Train Loss: 2.238, Val Loss: 2.612
Epoch 18, Train Loss: 2.232, Val Loss: 2.644
Epoch 19, Train Loss: 2.225, Val Loss: 2.767
Train Loss at epoch 11 (best model): 2.252
Val Loss at epoch 11 (best model): 2.324
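The selection rule used by the training loop can be replayed on the losses printed above. The snippet below copies the logged values into plain lists and applies the same strict "val_loss < best_val_loss" comparison; it is only a re-enactment of the bookkeeping, not part of the repository.

```python
# losses copied from the training log above (epochs 0..19)
train_losses = [2.382, 2.335, 2.309, 2.306, 2.291, 2.274, 2.276, 2.260, 2.258, 2.262,
                2.248, 2.252, 2.244, 2.243, 2.247, 2.252, 2.238, 2.238, 2.232, 2.225]
val_losses = [3.534, 2.816, 2.611, 2.679, 2.421, 2.604, 2.409, 2.449, 2.430, 2.427,
              2.400, 2.324, 2.442, 2.368, 2.393, 2.899, 2.420, 2.612, 2.644, 2.767]

best_epoch, best_train_loss, best_val_loss = None, float("inf"), float("inf")
for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses)):
    # a checkpoint is marked as best only when validation loss strictly improves
    if val_loss < best_val_loss:
        best_epoch, best_train_loss, best_val_loss = epoch, train_loss, val_loss

print(best_epoch, best_train_loss, best_val_loss)  # 11 2.252 2.324
```

Note that best_train_loss is the training loss at the best validation epoch (2.252), not the lowest training loss seen during the run (2.225 at epoch 19).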


## Model testing

In [None]:
best_epoch, best_train_loss, best_val_loss = model.load_checkpoint(os.path.join("logs", args.save), optimizer)
model.eval()

test_loss = run(best_epoch, "test", test_loader)
print("Test Loss at epoch {}: {:.3f}".format(best_epoch, test_loss))

Test Loss at epoch 11: 2.334


# **Application on Downstream Task**
In this section of the Colab, we use the pretrained embeddings obtained from the self-supervised model and train only the final few layers to perform graph classification.

## Setting up arguments to train the classifier head

In [None]:
'''
Change these arguments to change either the dataset / model / train data percent
train_data_percent is the fraction of training data which has labels associated. The utility of self-supervised 
training can be seen when train_data_percent is low and we can't train the entire model end-to-end.
NOTE: The load argument will be the same as the save argument from the self-supervised training procedure
'''

args = {
    "device" : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "save" : "downstream_model",
    "load" : "ssl_model",
    "lr" : 0.001,
    "epochs" : 20,
    "batch_size" : 64,
    "num_workers" : 2,
    "dataset" : "proteins", # Choices are ["proteins", "enzymes", "collab", 
                            # "reddit_binary", "reddit_multi", "imdb_binary", 
                            # "imdb_multi", "dd", "mutag", "nci1"]
    "model" : "gcn", # choices are ["gcn", "gin", "resgcn", "gat", "graphsage", "sgc"]
    "feat_dim" : 128,
    "layers" : 3,
    "train_data_percent" : 1.0,
}

args = AttributeDict(args)
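To illustrate what lowering train_data_percent means, the sketch below splits graph indices and then keeps only a fraction of the training split as labelled data. The function is hypothetical and is not the repository's split_dataset: the 70/20/10 ratio is inferred from the split sizes printed earlier (779 / 222 / 112 out of 1113 graphs), and the shuffling and rounding are assumptions.

```python
import random

def split_indices(num_graphs, train_data_percent=1.0, seed=0):
    """Hypothetical 70/20/10 split that keeps a fraction of the train split."""
    indices = list(range(num_graphs))
    random.Random(seed).shuffle(indices)

    # integer arithmetic avoids floating-point rounding surprises
    n_train = num_graphs * 7 // 10
    n_val = num_graphs * 2 // 10
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]

    # only this many training graphs are treated as labelled
    n_labelled = max(1, int(len(train) * train_data_percent))
    return train[:n_labelled], val, test_idx

train, val, test_idx = split_indices(1113)
print(len(train), len(val), len(test_idx))  # 779 222 112

few_labels, _, _ = split_indices(1113, train_data_percent=0.1)
print(len(few_labels))  # 77
```

With train_data_percent set to 0.1, the classifier would see only 77 labelled graphs while validation and test sets stay the same size, which is the regime where a pretrained encoder is expected to help most.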

## Loading the dataset and creating dataloaders

In [None]:
dataset, input_dim, num_classes = load_dataset(args.dataset)

# split the data into train / val / test sets
train_dataset, val_dataset, test_dataset = split_dataset(dataset, args.train_data_percent)

# build_classification_loader is a dataloader which gives one graph at a time
train_loader = build_classification_loader(args, train_dataset, "train")
val_loader = build_classification_loader(args, val_dataset, "val")
test_loader = build_classification_loader(args, test_dataset, "test")

print("Dataset split: {} {} {}".format(len(train_dataset), len(val_dataset), len(test_dataset)))
print("Number of classes: {}".format(num_classes))

Dataset split: 779 222 112
Number of classes: 2


## Initializing the model and optimizer
Here, the model consists of the pretrained GNN encoder followed by classification layers.

In [None]:
# classification model is a GNN encoder followed by linear layer
model = GraphClassificationModel(input_dim, args.feat_dim, n_layers=args.layers, output_dim=num_classes, gnn=args.model, load=args.load)
model = model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

## Model training

In [None]:
def run(epoch, mode, dataloader):
    if mode == "train":
        model.train()
    elif mode == "val" or mode == "test":
        model.eval()

    # CrossEntropy loss since it is a classification task
    loss_fn = torch.nn.CrossEntropyLoss()

    losses = []
    correct = 0
    for data in dataloader:
        data.to(args.device)

        data_input = data.x, data.edge_index, data.batch
        labels = data.y

        # only track gradients during training; val / test passes skip autograd bookkeeping
        with torch.set_grad_enabled(mode == "train"):
            # get class scores from model
            scores = model(data_input)

            # compute cross entropy loss
            loss = loss_fn(scores, labels)

        if mode == "train":
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # keep track of loss and accuracy
        pred = scores.argmax(dim=1)
        correct += int((pred == labels).sum())
        losses.append(loss.item())

    # gather the results for the epoch: mean of per-batch losses, accuracy over all graphs
    epoch_loss = sum(losses) / len(losses)
    accuracy = correct / len(dataloader.dataset)
    return epoch_loss, accuracy
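The two metrics returned by run are aggregated differently: the epoch loss is the mean of per-batch mean losses, while accuracy is counted over every graph in the dataset. The following pure-Python sketch reproduces that bookkeeping on made-up class scores; the helper names and the toy numbers are illustrative only.

```python
import math

def cross_entropy(scores, label):
    """Cross-entropy of one sample: log-sum-exp of the scores minus the true-class score."""
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[label]

def evaluate(batches):
    """Mirror run(): average the per-batch mean losses, count accuracy over all samples."""
    losses, correct, total = [], 0, 0
    for scores_batch, labels in batches:
        batch_losses = [cross_entropy(s, y) for s, y in zip(scores_batch, labels)]
        losses.append(sum(batch_losses) / len(batch_losses))
        # argmax over class scores gives the predicted class
        preds = [max(range(len(s)), key=s.__getitem__) for s in scores_batch]
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    return sum(losses) / len(losses), correct / total

# two toy batches of (class scores, labels) for a 2-class problem
toy_batches = [
    ([[2.0, 0.0], [0.0, 1.0]], [0, 1]),  # both predictions correct
    ([[0.5, 1.5]], [0]),                 # prediction is class 1, label is class 0
]
epoch_loss, accuracy = evaluate(toy_batches)
print(epoch_loss, accuracy)
```

Because the final batch of an epoch is usually smaller than the others, averaging per-batch losses weights its samples slightly more than a per-sample average would; the accuracy is unaffected by this.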

In [None]:
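# create the checkpoint directory if it does not exist yet
os.makedirs(os.path.join("logs", args.save), exist_ok=True)

# best_epoch is initialised so the final print cannot fail if val_loss never improves
best_epoch, best_train_loss, best_val_loss = None, float("inf"), float("inf")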

for epoch in range(args.epochs):
    train_loss, train_acc = run(epoch, "train", train_loader)
    val_loss, val_acc = run(epoch, "val", val_loader)
    log = "Epoch {}, Train Loss: {:.3f}, Train Accuracy: {:.3f}, Val Loss: {:.3f}, Val Accuracy: {:.3f}"
    print(log.format(epoch, train_loss, train_acc, val_loss, val_acc))

    # save model
    is_best_loss = False
    if val_loss < best_val_loss:
        best_epoch, best_train_loss, best_val_loss, is_best_loss = epoch, train_loss, val_loss, True

    model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_loss, best_val_loss, is_best_loss)

print("Train Loss at epoch {} (best model): {:.3f}".format(best_epoch, best_train_loss))
print("Val Loss at epoch {} (best model): {:.3f}".format(best_epoch, best_val_loss))

Epoch 0, Train Loss: 3.884, Train Accuracy: 0.540, Val Loss: 1.971, Val Accuracy: 0.640
Epoch 1, Train Loss: 2.369, Train Accuracy: 0.629, Val Loss: 1.478, Val Accuracy: 0.622
Epoch 2, Train Loss: 1.658, Train Accuracy: 0.614, Val Loss: 1.615, Val Accuracy: 0.644
Epoch 3, Train Loss: 1.422, Train Accuracy: 0.660, Val Loss: 1.500, Val Accuracy: 0.640
Epoch 4, Train Loss: 1.217, Train Accuracy: 0.635, Val Loss: 1.501, Val Accuracy: 0.631
Epoch 5, Train Loss: 1.138, Train Accuracy: 0.632, Val Loss: 1.237, Val Accuracy: 0.667
Epoch 6, Train Loss: 1.085, Train Accuracy: 0.655, Val Loss: 0.950, Val Accuracy: 0.667
Epoch 7, Train Loss: 0.979, Train Accuracy: 0.668, Val Loss: 1.021, Val Accuracy: 0.667
Epoch 8, Train Loss: 0.988, Train Accuracy: 0.682, Val Loss: 1.285, Val Accuracy: 0.667
Epoch 9, Train Loss: 1.021, Train Accuracy: 0.661, Val Loss: 1.221, Val Accuracy: 0.658
Epoch 10, Train Loss: 1.058, Train Accuracy: 0.687, Val Loss: 1.295, Val Accuracy: 0.644
Epoch 11, Train Loss: 1.198, Tr

## Model testing

In [None]:
best_epoch, best_train_loss, best_val_loss = model.load_checkpoint(os.path.join("logs", args.save), optimizer)
model.eval()

test_loss, test_accuracy = run(best_epoch, "test", test_loader)
print("Test Loss at epoch {}: {:.3f}, Test Accuracy: {:.3f}".format(best_epoch, test_loss, test_accuracy))

Test Loss at epoch 6: 0.838, Test Accuracy: 0.696
