# Tutorial for sslgraph

In this hands-on code tutorial, we show how to reproduce and develop self-supervised learning (SSL) methods using our DIG library. Specifically, we show how to implement existing SSL methods, how to develop and evaluate your own methods, and how to extract the embeddings generated by SSL methods.

## 1. Implementation of existing methods

In [3]:
from dig.sslgraph.utils import Encoder
from dig.sslgraph.dataset import get_dataset, get_node_dataset
from dig.sslgraph.evaluation import GraphUnsupervised, NodeUnsupervised, GraphSemisupervised
from dig.sslgraph.method import InfoGraph, MVGRL, GRACE, GraphCL

### 1.1 Unsupervised learning -- graph level

You need to download a graph-level dataset, define a model, and then evaluate it following our standard evaluation process. Here we show two examples of unsupervised graph-level tasks with existing methods, InfoGraph and MVGRL.

In [4]:
dataset = get_dataset('MUTAG', task='unsupervised')

embed_dim = 512
encoder = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 
                  n_layers=4, gnn='gin', node_level=True)
infograph = InfoGraph(g_dim=embed_dim*4, n_dim=embed_dim)

evaluator = GraphUnsupervised(dataset, log_interval=1)
evaluator.setup_train_config(batch_size=256, p_lr=0.0001, p_epoch=20)
evaluator.evaluate(learning_model=infograph, encoder=encoder)

Pretraining: epoch 20: 100%|█████████████████████████████| 20/20 [05:47<00:00, 17.37s/it, loss=0.044081]

Best epoch 20: acc 0.8936 +/-(0.0555)

(0.8935672514619883, 0.05553332204466091)

In [6]:
dataset = get_dataset('MUTAG', task='unsupervised')

embed_dim = 512
encoder_adj = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 
                      n_layers=4, gnn='gcn', node_level=True, act='prelu')
encoder_diff = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 
                       n_layers=4, gnn='gcn', node_level=True, act='prelu', edge_weight=True)
mvgrl = MVGRL(g_dim=embed_dim*4, n_dim=embed_dim, diffusion_type='ppr', device=7)

evaluator = GraphUnsupervised(dataset, log_interval=2, device=7)
evaluator.setup_train_config(batch_size=256, p_lr=0.001, p_epoch=20)
evaluator.evaluate(learning_model=mvgrl, encoder=[encoder_adj, encoder_diff])

Pretraining: epoch 20: 100%|████████████████████████████| 20/20 [00:53<00:00,  2.67s/it, loss=-0.021439]

Best epoch 18: acc 0.8991 +/-(0.0981)

(0.899122807017544, 0.09813851723297315)

### 1.2 Unsupervised learning -- node level

You need to download a node-level dataset, define a model, and then evaluate it following our standard evaluation process. Here we show one example of an unsupervised node-level task with an existing method, GRACE.

In [7]:
dataset = get_node_dataset('cora')

embed_dim = 128
encoder = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 
                  n_layers=2, gnn='gcn', node_level=True, graph_level=False)
grace = GRACE(dim=embed_dim, dropE_rate_1=0.2, dropE_rate_2=0.4, 
              maskN_rate_1=0.3, maskN_rate_2=0.4, tau=0.4, device=3)

evaluator = NodeUnsupervised(dataset, device=3, log_interval=100)
evaluator.setup_train_config(p_lr=0.0005, p_epoch=2000, p_weight_decay=1e-5, comp_embed_on='cpu')
evaluator.evaluate(learning_model=grace, encoder=encoder)

Pretraining: epoch 2000: 100%|███████████████████████| 2000/2000 [07:47<00:00,  4.28it/s, loss=7.770879]

Best epoch 1300: acc 0.8249 (+/- 0.0048).

0.8249000310897827
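
The `dropE_rate_1` and `dropE_rate_2` arguments above suggest that GRACE builds its two views by dropping edges at different rates. The following is a minimal sketch of that idea in plain Python, not the library's implementation; the helper name `drop_edges` and the toy edge list are hypothetical and exist only for illustration.

```python
import random

def drop_edges(edges, drop_rate, seed=None):
    """Return a copy of `edges` in which each edge is removed independently
    with probability `drop_rate` (hypothetical helper, not part of DIG)."""
    if not 0.0 <= drop_rate <= 1.0:
        raise ValueError("drop_rate must be in [0, 1]")
    rng = random.Random(seed)
    # rng.random() is in [0, 1), so drop_rate=0 keeps everything
    # and drop_rate=1 removes everything.
    return [edge for edge in edges if rng.random() >= drop_rate]

# Toy graph: a 5-node cycle stored as (source, target) pairs.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]

# Two views of the same graph, corrupted at different rates.
view_1 = drop_edges(edges, drop_rate=0.2, seed=0)
view_2 = drop_edges(edges, drop_rate=0.4, seed=1)
print(len(view_1), len(view_2))
```

Using different rates for the two views makes them differ in how strongly they are corrupted, which is presumably why the constructor exposes one rate per view.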

### 1.3 Semi-supervised learning -- graph level & grid search

You need to download a graph-level dataset in semi-supervised mode, define a model, and then evaluate it following our standard evaluation process. Here we show one example of a semi-supervised graph-level task with an existing method, GraphCL. In the semi-supervised setting, GraphCL uses ResGCN as the encoder. Available augmentations include dropN, maskN, permE, subgraph, and random[2-4]. In this example, we use a label rate of 10%. You can also perform evaluation with grid search.
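
To make the label rate concrete, the sketch below keeps labels for a fixed fraction of the training graphs and treats the rest as unlabeled. It is an illustration under assumptions: the helper `sample_labeled_indices` is hypothetical, and whether the library samples uniformly or stratifies by class is not stated in this tutorial.

```python
import random

def sample_labeled_indices(train_indices, label_rate, seed=0):
    """Pick the subset of training graphs whose labels are used for
    finetuning (hypothetical helper; uniform sampling is an assumption)."""
    rng = random.Random(seed)
    # Keep at least one labeled example even for very small label rates.
    n_labeled = max(1, round(len(train_indices) * label_rate))
    return sorted(rng.sample(train_indices, n_labeled))

train_indices = list(range(1000))
labeled = sample_labeled_indices(train_indices, label_rate=0.1)
print(len(labeled))  # 100
```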

In [8]:
dataset, dataset_pretrain = get_dataset('DD', task='semisupervised')
feat_dim = dataset[0].x.shape[1]
embed_dim = 128

encoder = Encoder(feat_dim, embed_dim, n_layers=3, gnn='resgcn')
graphcl = GraphCL(embed_dim, aug_1='subgraph', aug_2='dropN')

evaluator = GraphSemisupervised(dataset, dataset_pretrain, label_rate=0.1)
evaluator.evaluate(learning_model=graphcl, encoder=encoder)

Pretraining: epoch 100: 100%|██████████████████████████| 100/100 [09:48<00:00,  5.88s/it, loss=2.847807]
Fold 1, finetuning: 100%|████████████████| 100/100 [00:12<00:00,  8.11it/s, acc=0.7119, val_loss=1.6124]
Fold 2, finetuning: 100%|████████████████| 100/100 [00:11<00:00,  8.68it/s, acc=0.7797, val_loss=0.7528]
Fold 3, finetuning: 100%|████████████████| 100/100 [00:11<00:00,  8.48it/s, acc=0.7119, val_loss=1.3771]
Fold 4, finetuning: 100%|████████████████| 100/100 [00:11<00:00,  8.56it/s, acc=0.7797, val_loss=0.8696]
Fold 5, finetuning: 100%|████████████████| 100/100 [00:11<00:00,  8.87it/s, acc=0.7458, val_loss=1.1965]
Fold 6, finetuning: 100%|████████████████| 100/100 [00:12<00:00,  8.25it/s, acc=0.7542, val_loss=1.1717]
Fold 7, finetuning: 100%|████████████████| 100/100 [00:12<00:00,  7.96it/s, acc=0.7542, val_loss=1.2606]
Fold 8, finetuning: 100%|████████████████| 100/100 [00:11<00:00,  8.49it/s, acc=0.7203, val_loss=1.2794]
Fold 9, finetuning: 100%|████████████████| 100/100 [00:

(0.7699044942855835, 0.03682943806052208)
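
The returned tuple pairs a mean accuracy with its spread across folds. A minimal sketch of that aggregation follows, assuming the second value is a standard deviation over per-fold accuracies; whether the library uses the population or the sample standard deviation is an assumption here, and the fold values are illustrative rather than taken from the run above.

```python
from statistics import mean, pstdev

def summarize_folds(fold_accuracies):
    """Return (mean, population standard deviation) of per-fold accuracies."""
    return mean(fold_accuracies), pstdev(fold_accuracies)

# Illustrative values only, not results from the run above.
fold_accuracies = [0.70, 0.75, 0.80, 0.75]

acc_mean, acc_std = summarize_folds(fold_accuracies)
print(f"acc {acc_mean:.4f} +/-({acc_std:.4f})")
```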

In [29]:
evaluator.grid_search(learning_model=graphcl, encoder=encoder,
                      p_lr_lst=[0.01,0.001], p_epoch_lst=[20,40])

Pretraining: epoch 20: 100%|████████████████████| 20/20 [01:41<00:00,  5.06s/it, loss=3.040437]
Fold 1, finetuning: 100%|███████| 100/100 [00:07<00:00, 12.93it/s, acc=0.6949, val_loss=1.5951]
Fold 2, finetuning: 100%|███████| 100/100 [00:07<00:00, 13.15it/s, acc=0.7712, val_loss=0.9050]
Fold 3, finetuning: 100%|███████| 100/100 [00:07<00:00, 12.80it/s, acc=0.7119, val_loss=1.5409]
Fold 4, finetuning: 100%|███████| 100/100 [00:07<00:00, 13.13it/s, acc=0.7712, val_loss=0.8240]
Fold 5, finetuning: 100%|███████| 100/100 [00:07<00:00, 13.09it/s, acc=0.7203, val_loss=1.2548]
Fold 6, finetuning: 100%|███████| 100/100 [00:07<00:00, 12.99it/s, acc=0.7966, val_loss=1.1412]
Fold 7, finetuning: 100%|███████| 100/100 [00:08<00:00, 12.40it/s, acc=0.8136, val_loss=0.9092]
Fold 8, finetuning: 100%|███████| 100/100 [00:07<00:00, 12.58it/s, acc=0.6356, val_loss=1.6541]
Fold 9, finetuning: 100%|███████| 100/100 [00:07<00:00, 12.97it/s, acc=0.6154, val_loss=4.1965]
Fold 10, finetuning: 100%|██████| 100/10

Best paras: 20 epoch, lr=0.010000, acc=0.7623

(0.7622699737548828, 0.048677004873752594, (0.01, 20))
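
The grid search result above has the shape `(accuracy, std, (lr, epochs))`. The sketch below shows one way such a search can be written: try every combination of the two lists and keep the best-scoring one. The `grid_search` function and the `toy_evaluate` scoring function are hypothetical stand-ins, not the library's code, and the toy scores are made up.

```python
from itertools import product

def grid_search(evaluate_fn, p_lr_lst, p_epoch_lst):
    """Evaluate every (lr, epochs) pair and return the best as
    (accuracy, std, (lr, epochs))."""
    best = None
    for p_lr, p_epoch in product(p_lr_lst, p_epoch_lst):
        acc, std = evaluate_fn(p_lr, p_epoch)
        if best is None or acc > best[0]:
            best = (acc, std, (p_lr, p_epoch))
    return best

def toy_evaluate(p_lr, p_epoch):
    """Made-up deterministic score standing in for pretraining plus
    finetuning; it peaks at lr=0.01 and 20 epochs."""
    acc = 1.0 - abs(p_lr - 0.01) * 10 - abs(p_epoch - 20) / 100
    return acc, 0.0

result = grid_search(toy_evaluate, p_lr_lst=[0.01, 0.001], p_epoch_lst=[20, 40])
print(result)  # (1.0, 0.0, (0.01, 20))
```

Note that the cost grows with the product of the list lengths, since each combination requires a full pretraining and finetuning run.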

## 2. Develop & evaluate your own method

You can always write your own code to evaluate the contrastive methods defined above more flexibly. However, we also provide pre-implemented evaluation tools for more convenient evaluation. These tools work with most datasets from pytorch-geometric.

In [9]:
from dig.sslgraph.method.contrastive.views_fn import NodeAttrMask
from dig.sslgraph.method import Contrastive
from dig.sslgraph.dataset import get_dataset
from dig.sslgraph.utils import Encoder
from dig.sslgraph.evaluation import GraphSemisupervised

class SSLModel(Contrastive):
    def __init__(self, z_dim, mask_ratio, **kwargs):

        objective = "JSE"
        proj="MLP"
        mask_i = NodeAttrMask(mask_ratio=mask_ratio)
        mask_j = NodeAttrMask(mask_ratio=mask_ratio)
        views_fn = [mask_i, mask_j]

        super(SSLModel, self).__init__(objective=objective,
                                    views_fn=views_fn,
                                    z_dim=z_dim,
                                    proj=proj,
                                    node_level=False,
                                    **kwargs)

    def train(self, encoder, data_loader, optimizer, epochs, per_epoch_out=False):
        for enc, proj in super(SSLModel, self).train(encoder, data_loader,
                                                    optimizer, epochs, per_epoch_out):
            yield enc

dataset, dataset_pretrain = get_dataset('NCI1', task='semisupervised')
feat_dim = dataset[0].x.shape[1]
embed_dim = 128

encoder = Encoder(feat_dim, embed_dim, n_layers=3, gnn='resgcn')
ssl_model = SSLModel(z_dim=embed_dim, mask_ratio=0.1)
evaluator = GraphSemisupervised(dataset, dataset_pretrain, label_rate=0.01)
evaluator.evaluate(learning_model=ssl_model, encoder=encoder)

Pretraining: epoch 100: 100%|█████████████████████████| 100/100 [06:59<00:00,  4.20s/it, loss=-0.704380]
Fold 1, finetuning: 100%|████████████████| 100/100 [00:20<00:00,  4.82it/s, acc=0.6375, val_loss=2.6667]
Fold 2, finetuning: 100%|████████████████| 100/100 [00:21<00:00,  4.69it/s, acc=0.6448, val_loss=9.3357]
Fold 3, finetuning: 100%|████████████████| 100/100 [00:20<00:00,  4.84it/s, acc=0.5474, val_loss=2.6623]
Fold 4, finetuning: 100%|████████████████| 100/100 [00:20<00:00,  4.79it/s, acc=0.5888, val_loss=3.1011]
Fold 5, finetuning: 100%|████████████████| 100/100 [00:21<00:00,  4.75it/s, acc=0.5937, val_loss=5.9771]
Fold 6, finetuning: 100%|████████████████| 100/100 [00:20<00:00,  4.91it/s, acc=0.6715, val_loss=1.6416]
Fold 7, finetuning: 100%|████████████████| 100/100 [00:20<00:00,  4.80it/s, acc=0.6204, val_loss=2.0579]
Fold 8, finetuning: 100%|████████████████| 100/100 [00:20<00:00,  4.90it/s, acc=0.6277, val_loss=3.3737]
Fold 9, finetuning: 100%|████████████████| 100/100 [00:

(0.6204379796981812, 0.042001646012067795)
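
`SSLModel` above builds both of its views with `NodeAttrMask`. As a rough illustration of what masking node attributes means, the sketch below hides the features of a random fraction of nodes. It is not the library's implementation: the helper `mask_node_attributes` is hypothetical, and replacing masked rows with zeros is an assumption (the library may use a different mask value).

```python
import random

def mask_node_attributes(features, mask_ratio, seed=None):
    """Return a copy of `features` (one row per node) in which a random
    `mask_ratio` fraction of nodes has its feature row set to zeros."""
    rng = random.Random(seed)
    n_nodes = len(features)
    n_masked = int(n_nodes * mask_ratio)
    masked_ids = set(rng.sample(range(n_nodes), n_masked))
    return [
        [0.0] * len(row) if i in masked_ids else list(row)
        for i, row in enumerate(features)
    ]

# Toy graph with four nodes and one-hot features.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

# Two independently masked views of the same graph.
view_i = mask_node_attributes(features, mask_ratio=0.5, seed=0)
view_j = mask_node_attributes(features, mask_ratio=0.5, seed=1)
print(view_i)
print(view_j)
```

The contrastive objective then encourages the encoder to produce similar embeddings for the two masked views of the same graph.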

## 3. Extract embeddings for other tasks

You can also extract the graph embeddings generated by existing SSL methods and then apply them to any other downstream task.

In [13]:
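from dig.sslgraph.utils import Encoder
from dig.sslgraph.method import InfoGraph
from dig.sslgraph.dataset import get_dataset
from torch_geometric.loader import DataLoader
import torch

# Load the dataset first so that feat_dim is taken from MUTAG
# rather than from a dataset left over from an earlier cell.
dataset = get_dataset('MUTAG', task='unsupervised')

embed_dim = 512
encoder = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 
                  n_layers=4, gnn='gin', node_level=True)
infograph = InfoGraph(g_dim=embed_dim*4, n_dim=embed_dim)

pretrain_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
optimizer = torch.optim.Adam(encoder.parameters(), lr=0.01, weight_decay=0)

# train() is a generator (compare SSLModel.train in Section 2), so it must be
# iterated for pretraining to run. We assume it yields the trained encoder.
for encoder in infograph.train(encoder, pretrain_dataloader, optimizer, epochs=20):
    pass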

Here's the embedding of the first graph in the dataset.

In [14]:
embed = encoder(dataset[0])
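
A note on the pretraining call in this section: as the `SSLModel.train` override in Section 2 shows with its `yield`, `train` is written as a generator, so calling it only creates a generator object and the training loop runs once the generator is iterated. The toy function below illustrates that behavior in plain Python. It is a stand-in, not the library's `train`, and the assumption that `per_epoch_out=False` yields once after the final epoch is ours.

```python
def train(epochs, per_epoch_out=False):
    """Toy stand-in for a generator-style training loop."""
    state = 0
    for _ in range(epochs):
        state += 1  # placeholder for one epoch of optimization
        if per_epoch_out:
            yield state  # hand back the model after every epoch
    if not per_epoch_out:
        yield state  # hand back the model once, after the last epoch

gen = train(epochs=3)    # nothing has run yet
final_state = next(gen)  # runs all three epochs, then yields once
print(final_state)       # 3

per_epoch = list(train(epochs=3, per_epoch_out=True))
print(per_epoch)         # [1, 2, 3]
```

Yielding after every epoch is what lets an evaluator score intermediate encoders, which would explain the "Best epoch" lines in the outputs earlier in this tutorial.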