## Example of GraphCL with grid search

In [1]:
# --- Notebook setup: module search path, imports, logging and display ---
import sys

# Put the parent and grandparent directories at the front of the module
# search path. These inserts must run before the `dig` imports below.
# NOTE(review): presumably so a local checkout of `dig` is used - confirm.
sys.path.insert(0,'..')
sys.path.insert(0,'../..')

# Third-party libraries.
import torch
from rdkit import RDLogger
# `display`/`HTML` are imported from the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# DIG components used by the experiment cells.
from dig.sslgraph.utils import Encoder
from dig.sslgraph.evaluation import GraphUnsupervised
from dig.threedgraph.dataset import MoleculeNet
from dig.sslgraph.method import GraphCL

# Turn off RDKit log output for every logger matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

# Inject CSS so elements with the `.container` class use 90% of the width.
display(HTML("<style>.container { width:90% !important; }</style>"))

# Not referenced in the cells visible here; kept in case other cells use them.
import csv
import pandas as pd

# Tox21

In [None]:
# Load the MoleculeNet datasets under 'dataset/'. The trailing comments are the
# original author's task notes (the 12 / 27 match the `out_dim` values used for
# tox21 / sider in this notebook).
# NOTE(review): only `freesolv` and `sider` are used in this cell; the other
# loads are kept because other cells may rely on these names - confirm and trim.
esol = MoleculeNet(root='dataset/', name='esol')  # regression
freesolv = MoleculeNet(root='dataset/', name='freesolv')  # regression
lipo = MoleculeNet(root='dataset/', name='lipo')  # regression
hiv = MoleculeNet(root='dataset/', name='hiv')    # binary
bace = MoleculeNet(root='dataset/', name='bace')  # binary
bbbp = MoleculeNet(root='dataset/', name='bbbp')  # binary
clintox = MoleculeNet(root='dataset/', name='clintox')  # binary
tox21 = MoleculeNet(root='dataset/', name='tox21') # 12
sider = MoleculeNet(root='dataset/', name='sider')  # 27

# Experiment settings for this small run.
embed_dim = 64   # encoder hidden size and first argument to GraphCL
p_epoch = 5      # passed to the evaluator as `p_epoch`
n_folds = 2      # passed to the evaluator as `n_folds`
f_epoch = 1      # passed to the evaluator as `f_epoch`

# Prefer the second GPU (the device this notebook was written for), but fall
# back to the first GPU or the CPU so the cell does not fail on machines with
# fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

encoder = Encoder(gnn='schnet', energy_and_force=False, cutoff=10.0, num_layers=5,
                  hidden_channels=embed_dim, num_filters=128, num_gaussians=50)
graphcl = GraphCL(embed_dim, aug_1='ETKDG1', aug_2='ETKDG2', tau=0.2)

# Pretrain on FreeSolv, evaluate on SIDER (27 outputs) with a scaffold split.
evaluator = GraphUnsupervised(dataset_pretrain=freesolv, dataset=sider, out_dim=27,
                              split='scaffold', log_interval=10, p_lr = 0.0005,
                              p_epoch=p_epoch, device=device,
                              batch_size=32, f_lr=0.0001, n_folds=n_folds, f_epoch=f_epoch)

# Grid search over two `p_lr` values with a single value for every other list.
# NOTE(review): presumably these lists override the p_lr/p_epoch/f_lr/f_epoch
# values given to the constructor above - confirm against `grid_search`.
# The call is left as the cell's last expression so its return value is shown.
evaluator.grid_search(learning_model=graphcl, encoder=encoder, task_type='cls', 
                      p_lr_lst=[0.01, 0.001], p_epoch_lst=[1], f_lr_lst=[0.001], f_epoch_lst=[1])

In [3]:
import numpy as np

# Arithmetic mean of three hard-coded values; the last expression is the
# cell's displayed output.
a = [
    0.8582720588235294,
    0.9389868910826995,
    0.9200989486703772,
]
np.asarray(a).mean()

0.9057859661922021

In [4]:
# Load the MoleculeNet datasets under 'dataset/'. The trailing comments are the
# original author's task notes (the 12 / 27 match the `out_dim` values used for
# tox21 / sider in this notebook).
# NOTE(review): only `moleculenet` and `tox21` are used in this cell; the other
# loads are kept because other cells may rely on these names - confirm and trim.
esol = MoleculeNet(root='dataset/', name='esol')  # regression
freesolv = MoleculeNet(root='dataset/', name='freesolv')  # regression
lipo = MoleculeNet(root='dataset/', name='lipo')  # regression
hiv = MoleculeNet(root='dataset/', name='hiv')    # binary
bace = MoleculeNet(root='dataset/', name='bace')  # binary
bbbp = MoleculeNet(root='dataset/', name='bbbp')  # binary
clintox = MoleculeNet(root='dataset/', name='clintox')  # binary
tox21 = MoleculeNet(root='dataset/', name='tox21') # 12
sider = MoleculeNet(root='dataset/', name='sider')  # 27
moleculenet = MoleculeNet(root='dataset/', name='moleculenet') 

# Experiment settings for the Tox21 run.
embed_dim = 64   # encoder hidden size and first argument to GraphCL
p_epoch = 5      # passed to the evaluator as `p_epoch`
n_folds = 4      # passed to the evaluator as `n_folds`
f_epoch = 1      # passed to the evaluator as `f_epoch`

# Prefer the second GPU (the device this notebook was written for), but fall
# back to the first GPU or the CPU so the cell does not fail on machines with
# fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

encoder = Encoder(gnn='schnet', energy_and_force=False, cutoff=10.0, num_layers=5,
                  hidden_channels=embed_dim, num_filters=128, num_gaussians=50)
graphcl = GraphCL(embed_dim, aug_1='ETKDG1', aug_2='ETKDG2', tau=0.2)

# Pretrain on the 'moleculenet' dataset, evaluate on Tox21 (12 outputs) with a
# scaffold split.
evaluator = GraphUnsupervised(dataset_pretrain=moleculenet, dataset=tox21, out_dim=12,
                              split='scaffold', log_interval=10, p_lr = 0.0005,
                              p_epoch=p_epoch, device=device,
                              batch_size=32, f_lr=0.0001, n_folds=n_folds, f_epoch=f_epoch)

# Grid search: 3 p_lr x 4 p_epoch x 3 f_lr x 4 f_epoch candidate values.
# NOTE(review): the saved output of this cell shows `loss=nan` while
# pretraining (epoch 2) before the run was interrupted; investigate the NaN
# loss before trusting any result produced by this search.
# NOTE(review): presumably these lists override the p_lr/p_epoch/f_lr/f_epoch
# values given to the constructor above - confirm against `grid_search`.
loss_m_tox21,loss_sd_tox21, paras_tox21 = evaluator.grid_search(
    learning_model=graphcl, encoder=encoder, task_type='cls', 
    p_lr_lst=[0.005, 0.001, 0.00001], p_epoch_lst=[5, 10, 30, 50],
    f_lr_lst=[0.01, 0.001, 0.00001], f_epoch_lst=[5, 10, 30, 70])

ETKDG1
ETKDG2
proj_head in_dim 64
proj_head out_dim 64


Pretraining: epoch 2:  20%|██        | 1/5 [02:32<10:11, 152.94s/it, loss=nan]


KeyboardInterrupt: 