In [None]:
import os
import torch
os.environ["TORCH_VERSION"] = torch.__version__

!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-$TORCH_VERSION.html
!pip install torchdrug

Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 4.3 MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.0.9
Collecting torchdrug
  Downloading torchdrug-0.1.2.post1-py3-none-any.whl (191 kB)
[K     |████████████████████████████████| 191 kB 16.4 MB/s 
Collecting ninja
  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)
[K     |████████████████████████████████| 108 kB 71.2 MB/s 
Collecting rdkit-pypi
  Downloading rdkit_pypi-2021.9.2.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.6 MB)
[K     |████████████████████████████████| 20.6 MB 1.2 MB/s 
Installing collected packages: rdkit-pypi, ninja, torchdrug
Successfully installed ninja-1.10.2.3 rdkit-pypi-2021.9.2.1 torch

In [None]:
from torchdrug import data,core, models, tasks
from torch import nn, optim
import pandas as pd
import numpy as np

In [None]:
# Mount Google Drive: the input CSV, the model checkpoints and the exported
# results of this notebook all live under MyDrive/bio_project/chembl/.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


### **Load Dataset**

In [None]:
# Folder on the mounted Drive holding the input CSV and every checkpoint/output.
path = '/content/drive/MyDrive/bio_project/chembl/'
# NOTE(review): the "Unnamed: ..." columns in the preview look like index columns
# written by earlier `to_csv` round-trips; they are not used below.
csv_file = os.path.join(path, 'chembl_data.csv')
df = pd.read_csv(csv_file)
df.head()

Unnamed: 0.1,Unnamed: 0,assay_chembl_id,smiles,logP,QED
0,0,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,Cc1noc(C)c1CN1C(=O)C(=O)c2cc(C#N)ccc21,1.89262,0.757559
1,1,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(Cc2ccc(F)cc2Cl)c2ccc(I)cc21,3.8132,0.487042
2,2,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(CC2COc3ccccc3O2)c2ccc(I)cc21,2.6605,0.485762
3,3,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(Cc2cc3ccccc3s2)c2ccccc21,3.6308,0.683944
4,4,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(Cc2cc3ccccc3s2)c2c1cccc2[N+](=O)[O-],3.539,0.348717


### **Clean Dataset**

In [None]:
# Strip leading/trailing whitespace from every SMILES string so RDKit parses
# them cleanly. The vectorised `.str` accessor replaces the manual loop and,
# unlike calling `.strip()` per element, passes missing values (NaN) through
# instead of raising AttributeError.
df['smiles'] = df['smiles'].str.strip()
df.head()

Unnamed: 0.1,Unnamed: 0,assay_chembl_id,smiles,logP,QED
0,0,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,Cc1noc(C)c1CN1C(=O)C(=O)c2cc(C#N)ccc21,1.89262,0.757559
1,1,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(Cc2ccc(F)cc2Cl)c2ccc(I)cc21,3.8132,0.487042
2,2,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(CC2COc3ccccc3O2)c2ccc(I)cc21,2.6605,0.485762
3,3,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(Cc2cc3ccccc3s2)c2ccccc21,3.6308,0.683944
4,4,0 CHEMBL829584\n1 CHEMBL829584\n2 ...,O=C1C(=O)N(Cc2cc3ccccc3s2)c2c1cccc2[N+](=O)[O-],3.539,0.348717


In [None]:
# Build the TorchDrug dataset from the cleaned, in-memory frame. Loading
# chembl_data.csv from disk again would silently discard the whitespace
# stripping done in the previous cell. `load_smiles` takes the SMILES list plus
# a {target name: values} dict and accepts the same featurisation options
# (kekulize, node_feature) that were previously passed to `load_csv`.
chembl_dataset = data.MoleculeDataset()
chembl_dataset.load_smiles(df['smiles'].tolist(),
                           {field: df[field].tolist() for field in ['logP', 'QED']},
                           kekulize=True, node_feature="symbol")

### **Define and Pretrain Model**

In [None]:
# Pretrain GCPN on the ChEMBL molecules with the maximum-likelihood (NLL) criterion.
# Seed torch's RNG so parameter initialisation and other torch random draws
# repeat across runs.
torch.manual_seed(0)

model = models.RGCN(input_dim=chembl_dataset.node_feature_dim,
                    num_relation=chembl_dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

# NOTE(review): max_node=38 / max_edge_unroll=12 look carried over from the
# TorchDrug ZINC250k example; confirm they cover the largest ChEMBL molecule.
task = tasks.GCPNGeneration(model, chembl_dataset.atom_types, max_edge_unroll=12,
                            max_node=38, criterion="nll")


optimizer = optim.Adam(task.parameters(), lr=1e-5)
# log_interval=1 printed metrics after every batch and overflowed the cell
# output ("Streaming output truncated"); log every 10 batches instead.
solver = core.Engine(task, chembl_dataset, None, None, optimizer,
                     gpus=(0,), batch_size=32, log_interval=10)

solver.train(num_epoch=150)
solver.save(path+'gcpn_chembl_150epoch.pkl')

09:47:04   Preprocess training set
09:47:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:47:04   Epoch 0 begin
09:47:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:47:05   edge acc: 0.0355913
09:47:05   edge loss: 1.10787
09:47:05   node1 acc: 0.126292
09:47:05   node1 loss: 2.35102
09:47:05   node2 acc: 0.00344432
09:47:05   node2 loss: 2.9349
09:47:05   stop acc: 0.0376523
09:47:05   stop bce loss: 0.688784
09:47:05   total loss: 7.08257




09:47:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:47:05   edge acc: 0.0313901
09:47:05   edge loss: 1.10725
09:47:05   node1 acc: 0.122197
09:47:05   node1 loss: 2.36882
09:47:05   node2 acc: 0.0100897
09:47:05   node2 loss: 2.94546
09:47:05   stop acc: 0.034632
09:47:05   stop bce loss: 0.684954
09:47:05   total loss: 7.10648
09:47:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:47:05   edge acc: 0.0328054
09:47:05   edge loss: 1.10653
09:47:05   node1 acc: 0.125566
09:47:05   node1 loss: 2.36865
09:47:05   node2 acc: 0.0961538
09:47:05   node2 loss: 2.94493
09:47:05   stop acc: 0.0349345
09:47:05   stop bce loss: 0.682208
09:47:05   total loss: 7.10231
09:47:05   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:47:05   edge acc: 0.0425777
09:47:05   edge loss: 1.10524
09:47:05   node1 acc: 0.149597
09:47:05   node1 loss: 2.35469
09:47:05   node2 acc: 0.201381
09:47:05   node2 loss: 2.93584
09:47:05   stop acc: 0.0355161
09:47:05   stop bce loss: 0.67949
09:47:05   total loss: 7.07527
09:47:05   >>>>>>>>>>>>>>>>>



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
09:48:28   edge acc: 0.681447
09:48:28   edge loss: 0.529962
09:48:28   node1 acc: 0.193699
09:48:28   node1 loss: 2.08352
09:48:28   node2 acc: 0.644107
09:48:28   node2 loss: 2.15957
09:48:28   stop acc: 0.707537
09:48:28   stop bce loss: 0.458346
09:48:28   total loss: 5.2314
09:48:28   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:48:28   edge acc: 0.659722
09:48:28   edge loss: 0.52063
09:48:28   node1 acc: 0.166667
09:48:28   node1 loss: 2.13803
09:48:28   node2 acc: 0.631944
09:48:28   node2 loss: 2.16833
09:48:28   stop acc: 0.66443
09:48:28   stop bce loss: 0.414347
09:48:28   total loss: 5.24134
09:48:28   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:48:28   Epoch 76 end
09:48:28   duration: 1.17 secs
09:48:28   speed: 4.26 batch / sec
09:48:28   ETA: 1.32 mins
09:48:28   max GPU memory: 532.7 MiB
09:48:28   ------------------------------
09:48:28   average edge acc: 0.676718
09:48:28   average edge loss: 0.530796
09:48:28   average n

In [None]:
# Resume the same solver for 300 more epochs on top of the 150 already run,
# for 450 epochs in total (the number used in the checkpoint name below).
extra_epochs = 300
solver.train(num_epoch=extra_epochs)

09:50:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:50:56   Epoch 150 begin
09:50:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:50:56   edge acc: 0.811834
09:50:56   edge loss: 0.435139
09:50:56   node1 acc: 0.313609
09:50:56   node1 loss: 1.85555
09:50:56   node2 acc: 0.64497
09:50:56   node2 loss: 1.83054
09:50:56   stop acc: 0.800456
09:50:56   stop bce loss: 0.392518
09:50:56   total loss: 4.51374




09:50:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:50:56   edge acc: 0.824683
09:50:56   edge loss: 0.434765
09:50:56   node1 acc: 0.33218
09:50:56   node1 loss: 1.90245
09:50:56   node2 acc: 0.61015
09:50:56   node2 loss: 1.85946
09:50:56   stop acc: 0.839822
09:50:56   stop bce loss: 0.414648
09:50:56   total loss: 4.61133
09:50:57   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:50:57   edge acc: 0.839326
09:50:57   edge loss: 0.452597
09:50:57   node1 acc: 0.292135
09:50:57   node1 loss: 1.91264
09:50:57   node2 acc: 0.644944
09:50:57   node2 loss: 1.8036
09:50:57   stop acc: 0.850325
09:50:57   stop bce loss: 0.396093
09:50:57   total loss: 4.56494
09:50:57   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:50:57   edge acc: 0.843956
09:50:57   edge loss: 0.426651
09:50:57   node1 acc: 0.314286
09:50:57   node1 loss: 1.92846
09:50:57   node2 acc: 0.61978
09:50:57   node2 loss: 1.83034
09:50:57   stop acc: 0.823779
09:50:57   stop bce loss: 0.399079
09:50:57   total loss: 4.58453
09:50:57   >>>>>>>>>>>>>>>>>>>>>>>>



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
09:55:14   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:55:14   edge acc: 0.838323
09:55:14   edge loss: 0.392142
09:55:14   node1 acc: 0.343713
09:55:14   node1 loss: 1.73431
09:55:14   node2 acc: 0.632335
09:55:14   node2 loss: 1.31757
09:55:14   stop acc: 0.866205
09:55:14   stop bce loss: 0.280117
09:55:14   total loss: 3.72414
09:55:14   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:55:14   edge acc: 0.806452
09:55:14   edge loss: 0.404472
09:55:14   node1 acc: 0.387097
09:55:14   node1 loss: 1.8218
09:55:14   node2 acc: 0.670968
09:55:14   node2 loss: 1.21171
09:55:14   stop acc: 0.75
09:55:14   stop bce loss: 0.301395
09:55:14   total loss: 3.73938
09:55:14   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
09:55:14   Epoch 376 end
09:55:14   duration: 1.16 secs
09:55:14   speed: 4.31 batch / sec
09:55:14   ETA: 1.72 mins
09:55:14   max GPU memory: 524.1 MiB
09:55:14   ------------------------------
09:55:14   average edge acc: 0.83222
09:55:14   average 

In [None]:
# Persist the 450-epoch pretrained model; the fine-tuning section loads this file.
solver.save(os.path.join(path, 'gcpn_chembl_450epoch.pkl'))

09:57:55   Save checkpoint to /content/drive/MyDrive/bio_project/chembl/gcpn_chembl_450epoch.pkl


### **Fine-tuning**

In [None]:
# Fine-tune the pretrained GCPN with reinforcement learning: PPO on rewards for
# QED and penalized logP, together with the NLL criterion. The RGCN is rebuilt
# with exactly the pretraining hyperparameters so the checkpoint's weights fit.
model = models.RGCN(input_dim=chembl_dataset.node_feature_dim,
                    num_relation=chembl_dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

task = tasks.GCPNGeneration(model, chembl_dataset.atom_types,
                            max_edge_unroll=12, max_node=38,
                            task=('qed','plogp'), criterion=('ppo', 'nll'),
                            reward_temperature=1,
                            agent_update_interval=3, gamma=0.9)

optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, chembl_dataset, None, None, optimizer,
                     gpus=(0,), batch_size=32, log_interval=10)

# Load only the pretrained weights; the new optimizer starts from scratch.
solver.load(path+'gcpn_chembl_450epoch.pkl',
            load_optimizer=False)

# RL
# Seed torch's RNG so the sampled rollouts repeat across runs
# (modulo any nondeterministic CUDA kernels).
torch.manual_seed(0)
solver.train(num_epoch=10)
# NOTE(review): despite its name, this checkpoint is fine-tuned on ChEMBL, not
# ZINC250k. The name is kept because the generation section loads this exact file.
solver.save(path+'gcpn_zinc250k_10epoch_finetune.pkl')

13:18:34   Preprocess training set
13:18:34   Load checkpoint from /content/drive/MyDrive/bio_project/chembl/gcpn_chembl_450epoch.pkl
13:18:34   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:18:34   Epoch 0 begin




13:18:38   Downloading https://github.com/rdkit/rdkit/raw/master/Contrib/SA_Score/fpscores.pkl.gz to /usr/local/lib/python3.7/dist-packages/torchdrug/metrics/rdkit/fpscores.pkl.gz
13:18:39   Extracting /usr/local/lib/python3.7/dist-packages/torchdrug/metrics/rdkit/fpscores.pkl.gz to /usr/local/lib/python3.7/dist-packages/torchdrug/metrics/rdkit/fpscores.pkl
13:18:40   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:18:40   PPO objective: 2.19318
13:18:40   Penalized logP: -6.24947
13:18:40   Penalized logP (max): 2.1055
13:18:40   QED: 0.510817
13:18:40   QED (max): 0.732614
13:18:40   edge acc: 0.835821
13:18:40   edge loss: 0.379763
13:18:40   node1 acc: 0.34558
13:18:40   node1 loss: 1.73914
13:18:40   node2 acc: 0.670494
13:18:40   node2 loss: 1.23381
13:18:40   stop acc: 0.766334
13:18:40   stop bce loss: 0.298387
13:18:40   total loss: 3.6511
13:18:56   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:18:56   Epoch 0 end
13:18:56   duration: 22.16 secs
13:18:56   speed: 0.23 batch / sec
13:18:56   ETA: 3.32 



13:21:00   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:21:00   Epoch 7 end
13:21:00   duration: 14.89 secs
13:21:00   speed: 0.34 batch / sec
13:21:00   ETA: 36.56 secs
13:21:00   max GPU memory: 607.9 MiB
13:21:00   ------------------------------
13:21:00   average PPO objective: 1.16821
13:21:00   average Penalized logP: -3.22437
13:21:00   average Penalized logP (max): 1.52571
13:21:00   average QED: 0.520951
13:21:00   average QED (max): 0.725005
13:21:00   average edge acc: 0.828629
13:21:00   average edge loss: 0.381151
13:21:00   average node1 acc: 0.322784
13:21:00   average node1 loss: 1.74474
13:21:00   average node2 acc: 0.668132
13:21:00   average node2 loss: 1.2649
13:21:00   average stop acc: 0.808661
13:21:00   average stop bce loss: 0.26875
13:21:00   average total loss: 3.65954
13:21:00   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:21:00   Epoch 8 begin
13:21:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:21:04   PPO objective: 0.750417
13:21:04   Penalized logP: -6.59524
13:21:04   Penalized log

### **Generate Molecules**

In [None]:
# Rebuild the fine-tuned GCPN (same hyperparameters as in fine-tuning) and
# sample molecules from it. The optimizer/Engine here are only used to load
# the checkpoint; no training happens in this cell.
model = models.RGCN(input_dim=chembl_dataset.node_feature_dim,
                    num_relation=chembl_dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

task = tasks.GCPNGeneration(model, chembl_dataset.atom_types,
                            max_edge_unroll=12, max_node=38,
                            task=('qed','plogp'), criterion=('ppo', 'nll'),
                            reward_temperature=1,
                            agent_update_interval=3, gamma=0.9)

optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, chembl_dataset, None, None, optimizer,
                     gpus=(0,), batch_size=32, log_interval=10)

solver.load(path+'gcpn_zinc250k_10epoch_finetune.pkl')
# Seed torch's RNG so the 100 molecules analysed and exported below can be
# regenerated (modulo any nondeterministic CUDA kernels).
torch.manual_seed(0)
results = task.generate(num_sample=100, max_resample=5)
all_smiles = results.to_smiles()

16:02:28   Preprocess training set
16:02:36   Load checkpoint from /content/drive/MyDrive/bio_project/chembl/gcpn_zinc250k_10epoch_finetune.pkl




### **Analyze the Results**

In [None]:
# rdkit is already installed: `pip install torchdrug` pulled in the
# `rdkit-pypi` wheel (see the install log at the top of the notebook).
# Reinstalling it through Miniconda with `-f -p /usr/local` would overwrite the
# Colab Python installation, so only confirm that rdkit is importable here.
import rdkit
rdkit.__version__

--2021-11-27 16:03:22--  https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.131.3, 104.16.130.3, 2606:4700::6810:8303, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.131.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 85055499 (81M) [application/x-sh]
Saving to: ‘Miniconda3-py37_4.8.2-Linux-x86_64.sh’


2021-11-27 16:03:24 (72.8 MB/s) - ‘Miniconda3-py37_4.8.2-Linux-x86_64.sh’ saved [85055499/85055499]

PREFIX=/usr/local
Unpacking payload ...
Collecting package metadata (current_repodata.json): - \ done
Solving environment: / - done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - _libgcc_mutex==0.1=main
    - asn1crypto==1.3.0=py37_0
    - ca-certificates==2020.1.1=0
    - certifi==2019.11.28=py37_0
    - cffi==1.14.0=py37h2e261b9_0
    - chardet==3.0.4=py37_1003
    - conda-package-handling==1.6.0=py37h7b6447c_0
   

In [None]:
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

In [None]:
def calculate_logp_qed(smiles):
  """Compute RDKit logP and QED for each SMILES string.

  Args:
    smiles: iterable of SMILES strings.

  Returns:
    A ``(logP, qed)`` tuple of lists, each the same length as ``smiles`` and in
    the same order, so they can be placed next to ``smiles`` in one DataFrame.
    Strings RDKit cannot parse get ``nan`` in both lists instead of being
    dropped (dropping them would shift every later value onto the wrong row).
  """
  # Import explicitly instead of relying on `Descriptors` having loaded it.
  from rdkit.Chem import QED

  logP = []
  qed = []
  for smile in smiles:
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
      logP.append(float('nan'))
      qed.append(float('nan'))
      continue
    logP.append(Descriptors.MolLogP(mol))
    # NOTE(review): weights_max is QED with the "max" weighting, whereas RDKit's
    # default QED.qed uses mean weights. Confirm which variant produced the
    # dataset's QED column before comparing generated and training values.
    qed.append(QED.weights_max(mol))
  return (logP, qed)

In [None]:
logP, qed = calculate_logp_qed(all_smiles)
# The property lists must line up row-for-row with the generated SMILES.
assert len(logP) == len(qed) == len(all_smiles), "property lists misaligned with SMILES"
# Pass the columns inline: binding them to a variable named `data` would shadow
# the `torchdrug.data` module imported at the top and break re-running the
# dataset cell.
df2 = pd.DataFrame({'smiles': all_smiles, 'logP': logP, 'qed': qed})

In [None]:
# Preview the first few generated molecules with their computed properties.
df2.head(n=5)

Unnamed: 0,smiles,logP,qed
0,C=C(C)C(C)C,2.2185,0.451964
1,CCCC(C)C,2.4425,0.524779
2,CC(C)C(C)C,2.2984,0.49783
3,CCC(S)CC,2.1048,0.542195
4,CC1=CC=C1C,1.8926,0.45225


In [None]:
# Summary statistics of logP and QED over the generated molecules
# (quartiles spelled out; these are pandas' default percentiles).
df2.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,logP,qed
count,100.0,100.0
mean,4.178036,0.559319
std,1.710954,0.102879
min,1.7519,0.190986
25%,2.856375,0.506744
50%,3.7448,0.566942
75%,5.384425,0.625855
max,9.2713,0.794572


### **Export**

In [None]:
# index=False: otherwise the RangeIndex is written as an unnamed first column,
# which comes back as "Unnamed: 0" when the file is read again (the same
# artefact visible in chembl_data.csv at the top of this notebook).
df2.to_csv(path+'chembl_output.csv', index=False)