<a href="https://colab.research.google.com/github/goga0001/graph/blob/main/minimal%20repro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Manual Steps**

---




> Set the runtime to GPU via "Runtime > Change runtime type..."
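
After switching the runtime, it can help to confirm that a GPU is actually visible before installing anything. The sketch below is an assumption-laden convenience, not part of the original notebook: it treats a successful `nvidia-smi -L` call as a proxy for "GPU runtime is active", and the helper name `gpu_visible` is hypothetical.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi` is on PATH and exits successfully.

    Assumption: a working `nvidia-smi` is a reasonable proxy for a GPU runtime.
    """
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False
    try:
        # `-L` lists the detected GPUs, one per line.
        result = subprocess.run([exe, "-L"], capture_output=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


print("GPU visible:", gpu_visible())
```

If this prints `False`, the later `gpus=(0,)` argument to the solver is unlikely to work as written.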







# **Clone Git repository**

In [1]:
!git clone https://github.com/DeepGraphLearning/torchdrug

Cloning into 'torchdrug'...
remote: Enumerating objects: 1231, done.
remote: Counting objects: 100% (601/601), done.
remote: Compressing objects: 100% (265/265), done.
remote: Total 1231 (delta 363), reused 486 (delta 321), pack-reused 630
Receiving objects: 100% (1231/1231), 2.62 MiB | 27.42 MiB/s, done.
Resolving deltas: 100% (629/629), done.


In [2]:
cd torchdrug

/content/torchdrug


# **Install requirements**

In [None]:
!pip install decorator rdkit rdkit-pypi torch_geometric==2.0.4 matplotlib tqdm networkx jinja2 lmdb fair-esm

In [None]:
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.12.1+cu113.html
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.12.1+cu113.html
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.12.1+cu113.html
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.12.1+cu113.html
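
The four commands above all point at the same wheel index, whose URL embeds a torch version and CUDA tag (`torch-1.12.1+cu113`). A minimal sketch of how that URL could be derived from a version string, assuming the URL pattern seen above holds for other versions too; the helper name `pyg_wheel_index` is hypothetical, and the code does not check that the index page exists.

```python
def pyg_wheel_index(torch_version: str) -> str:
    """Build the wheel index URL for a version string such as '1.12.1+cu113'.

    Assumption: the index follows the pattern used in the install commands above.
    """
    if "+" not in torch_version:
        # Without a local tag (e.g. '+cu113' or '+cpu') the pattern is ambiguous.
        raise ValueError(f"expected '<version>+<tag>', got {torch_version!r}")
    return f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"


print(pyg_wheel_index("1.12.1+cu113"))
```

In the notebook, passing `torch.__version__` to this helper would show which index matches the torch build that is actually installed, which is worth re-checking after any later `pip install torch`.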

In [6]:
pip install fair-esm==0.1.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fair-esm==0.1.0
  Downloading fair_esm-0.1.0-py3-none-any.whl (23 kB)
Installing collected packages: fair-esm
  Attempting uninstall: fair-esm
    Found existing installation: fair-esm 2.0.0
    Uninstalling fair-esm-2.0.0:
      Successfully uninstalled fair-esm-2.0.0
Successfully installed fair-esm-0.1.0


In [5]:
!pip install Ninja==1.10.2.4

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting Ninja==1.10.2.4
  Downloading ninja-1.10.2.4-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (120 kB)
     |████████████████████████████████| 120 kB 30.6 MB/s
Installing collected packages: Ninja
Successfully installed Ninja-1.10.2.4


In [7]:
pip install torch torchvision

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# **Run setup.py**

In [None]:
!python setup.py install

# **Prepare the Pretraining Dataset (5 min)**

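Before registering the new dataset, upload the CSV file (https://raw.githubusercontent.com/goga0001/graph/main/data.csv) into the datasets folder, `/content/torchdrug/torchdrug/datasets/`, so that it is available as `data.csv`.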
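
As an alternative to a manual upload, the file could be fetched from a notebook cell. This is a sketch under assumptions: the destination directory is taken from the `csv_file` path used by the dataset class, the helper name `fetch_csv` is hypothetical, and nothing here verifies the file contents.

```python
import os
import urllib.request

CSV_URL = "https://raw.githubusercontent.com/goga0001/graph/main/data.csv"
# Assumed destination: the directory of the `csv_file` path in the dataset class.
DATASET_DIR = "/content/torchdrug/torchdrug/datasets"


def fetch_csv(url: str = CSV_URL, dest_dir: str = DATASET_DIR) -> str:
    """Download `url` into `dest_dir` as data.csv and return the local path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, "data.csv")
    urllib.request.urlretrieve(url, dest)
    return dest
```

Calling `fetch_csv()` with no arguments would place the file where the `WB6` class expects it.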

In [9]:
from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.WB6")
@utils.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class WB6(data.MoleculeDataset):
    """
    Subset of random compound database for virtual screening.

    Statistics:
        - #Molecule:  4806
        - #Regression task: 2

    Parameters:
        path (str): not used by ``__init__``; the CSV is always read from ``csv_file``
        verbose (int, optional): output verbose level
        **kwargs
    """

    csv_file = "/content/torchdrug/torchdrug/datasets/data.csv"
    target_fields = ["logP","qed"]

    def __init__(self, path, verbose=1, **kwargs):
        self.load_csv(self.csv_file, smiles_field="smiles", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)
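
The class above hard-codes three column names: `smiles`, `logP`, and `qed`. A quick header check before loading can catch a wrong or partially uploaded file early. This is a sketch using only the standard library; the helper name `missing_fields` is hypothetical and it assumes the first row of the CSV is the header.

```python
import csv

# Column names taken from the WB6 class definition above.
REQUIRED_FIELDS = ("smiles", "logP", "qed")


def missing_fields(csv_path, required=REQUIRED_FIELDS):
    """Return the required column names that are absent from the CSV header."""
    with open(csv_path, newline="") as f:
        header = next(csv.reader(f), [])  # empty file -> empty header
    return [name for name in required if name not in header]
```

In the notebook, `missing_fields("/content/torchdrug/torchdrug/datasets/data.csv")` should return an empty list if the file matches what the class expects.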

In [17]:
cd datasets/ 

/content/torchdrug/torchdrug/datasets


In [18]:
import torch
from torchdrug import core, models, tasks


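# The first positional argument fills `path`, which WB6.__init__ does not use;
# the file is read from WB6.csv_file.
dataset = WB6("/content/torchdrug/torchdrug/datasets/data.csv", kekulize=True,
              atom_feature="symbol")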






Loading /content/torchdrug/torchdrug/datasets/data.csv: 4807it [00:00, 59662.42it/s]            
Constructing molecules from SMILES: 100%|██████████| 4806/4806 [00:08<00:00, 538.50it/s]


# **Define the Model: GraphAF (30 sec)**

In [19]:
from torchdrug import core, models, tasks
from torchdrug.layers import distribution

model = models.RGCN(input_dim=dataset.num_atom_type,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256], batch_norm=True)

num_atom_type = dataset.num_atom_type
# add one class for non-edge
num_bond_type = dataset.num_bond_type + 1

node_prior = distribution.IndependentGaussian(torch.zeros(num_atom_type),
                                              torch.ones(num_atom_type))
edge_prior = distribution.IndependentGaussian(torch.zeros(num_bond_type),
                                              torch.ones(num_bond_type))
node_flow = models.GraphAF(model, node_prior, num_layer=12)
edge_flow = models.GraphAF(model, edge_prior, use_edge=True, num_layer=12)

task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                      max_node=38, max_edge_unroll=12,
                                      criterion="nll")

In [20]:
dataset.num_bond_type

2

In [21]:
model.layers[0].num_relation

2

In [23]:
dataset[2]["graph"].num_relation

# Use `item` rather than `data` so the loop does not shadow the torchdrug `data` module.
for item in dataset:
    item["graph"].num_relation = torch.tensor(2)

In [24]:
dataset[2]["graph"].num_relation

tensor(2)
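
The cells above compare relation counts by hand for a single sample. A small helper can scan every graph and report which ones disagree with the value the model was built with. This is a sketch: `relation_mismatches` is a hypothetical name, and the demo uses `SimpleNamespace` stand-ins rather than real torchdrug graphs so that it runs on its own.

```python
from types import SimpleNamespace


def relation_mismatches(expected, graphs):
    """Return indices of graphs whose `num_relation` differs from `expected`."""
    expected = int(expected)
    return [i for i, g in enumerate(graphs) if int(g.num_relation) != expected]


# Stand-in objects for illustration only; they just carry a `num_relation` attribute.
demo_graphs = [SimpleNamespace(num_relation=2),
               SimpleNamespace(num_relation=4),
               SimpleNamespace(num_relation=2)]
print(relation_mismatches(2, demo_graphs))
```

In the notebook, the same call could take `model.layers[0].num_relation` as `expected` and `(dataset[i]["graph"] for i in range(len(dataset)))` as `graphs`; an empty result would mean the per-graph values agree with the model.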

# **Pretraining and Generation: GraphAF**

In [28]:
from torch import optim
optimizer = optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer,
                     gpus=(0,), batch_size=128, log_interval=10)

solver.train(num_epoch=1)
solver.save("graphaf_WB.pkl")



04:39:59   Preprocess training set


04:39:59   {'batch_size': 128,
 'class': 'core.Engine',
 'gpus': (0,),
 'gradient_interval': 1,
 'log_interval': 10,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'eps': 1e-08,
               'foreach': None,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 5,
          'baseline_momentum': 0.9,
          'class': 'tasks.AutoregressiveGeneration',
          'criterion': 'nll',
          'edge_model': {'class': 'models.GraphAF',
                         'dequantization_noise': 0.9,
                         'model': {'activation': 'relu',
                                   'batch_norm': True,
                                   'class': 'models.RGCN',
                                   'concat_hidden': False,
                                   'edge_input_dim': None,



04:39:59   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>




04:39:59   Epoch 0 begin


AssertionError: ignored

In [27]:
 
torch.cuda.empty_cache()

In [None]:
from collections import defaultdict

solver.load("graphaf_WB.pkl")
results = task.generate(num_sample=32)
print(results.to_smiles())



17:17:52   Load checkpoint from graphaf_flavonoid_1epoch.pkl




['NCl', 'P', 'Br', 'ClP1(Cl)=P(P(Cl)(Cl)(Cl)P2=P3=S=2(Cl)S3(Cl)(Cl)(Cl)P(Cl)(Cl)(Cl)Cl)=P1', 'Cl', 'O', 'Cl[SH]12(OBr)P=P13PN32', 'P=[PH]1P23=P4(Cl)P5(Cl)=P6=S5(Cl)(Br)S612=P43Cl', 'ClP12(Cl)N=N[PH]1=N2', 'S', 'ClSBr', 'O', 'P', 'CBr', 'Cl', 'CCl', 'S', 'NBr', 'Cl', 'P', 'ClBr', 'PBr', 'ClCl', 'N', 'OCCl', 'N', 'P=S', 'Cl', 'S', 'SCl', 'O', 'C=O']
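
The printed sample contains many repeated strings, so a quick string-level summary can make the output easier to read. This is a sketch: `summarize_smiles` is a hypothetical helper that only counts identical strings, and it does not check chemical validity or canonicalize the SMILES.

```python
from collections import Counter


def summarize_smiles(smiles):
    """Summarize a list of SMILES strings by exact string equality."""
    counts = Counter(smiles)
    total = len(smiles)
    return {
        "total": total,
        "unique": len(counts),
        "unique_ratio": len(counts) / total if total else 0.0,
        "most_common": counts.most_common(3),
    }


# Small illustrative input; in the notebook, pass results.to_smiles() instead.
print(summarize_smiles(["Cl", "O", "Cl", "C=O"]))
```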
