In [None]:
import os
import torch
os.environ["TORCH_VERSION"] = torch.__version__

!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-$TORCH_VERSION.html
!pip install torchdrug

Looking in links: https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 6.3 MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.0.9
Collecting torchdrug
  Downloading torchdrug-0.1.2.post1-py3-none-any.whl (191 kB)
[K     |████████████████████████████████| 191 kB 7.2 MB/s 
[?25hCollecting rdkit-pypi
  Downloading rdkit_pypi-2021.9.2.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.6 MB)
[K     |████████████████████████████████| 20.6 MB 1.4 MB/s 
Collecting ninja
  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (108 kB)
[K     |████████████████████████████████| 108 kB 71.4 MB/s 
Installing collected packages: rdkit-pypi, ninja, torchdrug
Successfully installed ninja-1.10.2.3 rdkit-pypi-2021.9.2.1 

In [None]:
from torchdrug import data,core, models, tasks
from torch import nn, optim
import pandas as pd
import numpy as np

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


###**Dataset**

In [None]:
path = '/content/drive/MyDrive/bio_project/zinc/zinc250k.csv'
df = pd.read_csv(path)
df.head()

Unnamed: 0,smiles,logP,qed,SAS
0,CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1\n,5.0506,0.702012,2.084095
1,C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1\n,3.1137,0.928975,3.432004
2,N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...,4.96778,0.599682,2.470633
3,CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...,4.00022,0.690944,2.822753
4,N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...,3.60956,0.789027,4.035182


###**cleaned dataset**

In [None]:
new_smiles = []
for smile in df['smiles']:
  new_smiles.append(smile.strip())

df['smiles'] = new_smiles
df.head()

Unnamed: 0,smiles,logP,qed,SAS
0,CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1,5.0506,0.702012,2.084095
1,C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1,3.1137,0.928975,3.432004
2,N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...,4.96778,0.599682,2.470633
3,CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...,4.00022,0.690944,2.822753
4,N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...,3.60956,0.789027,4.035182


In [None]:
zinc_dataset = data.MoleculeDataset()
zinc_dataset.load_csv(path, smiles_field='smiles',target_fields=['logP','QED'],kekulize=True,node_feature="symbol")

In [None]:
model = models.RGCN(input_dim=zinc_dataset.node_feature_dim,
                    num_relation=zinc_dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

task = tasks.GCPNGeneration(model, zinc_dataset.atom_types, max_edge_unroll=12,
                            max_node=38, criterion="nll")


optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, zinc_dataset, None, None, optimizer,
                     gpus=(0,), batch_size=32, log_interval=10)

solver.train(num_epoch=5)
solver.save("/content/drive/MyDrive/bio_project/zinc/gcpn_zinc250k_5epoch.pkl")


06:41:48   Preprocess training set
06:41:49   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
06:41:49   Epoch 0 begin




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
08:25:27   stop acc: 0.829208
08:25:27   stop bce loss: 0.252383
08:25:27   total loss: 2.78583
08:25:29   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
08:25:29   edge acc: 0.890052
08:25:29   edge loss: 0.268825
08:25:29   node1 acc: 0.496073
08:25:29   node1 loss: 1.24177
08:25:29   node2 acc: 0.708115
08:25:29   node2 loss: 0.869832
08:25:29   stop acc: 0.85804
08:25:29   stop bce loss: 0.215232
08:25:29   total loss: 2.59566
08:25:31   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
08:25:31   edge acc: 0.897668
08:25:31   edge loss: 0.249891
08:25:31   node1 acc: 0.483161
08:25:31   node1 loss: 1.26225
08:25:31   node2 acc: 0.708549
08:25:31   node2 loss: 0.88286
08:25:31   stop acc: 0.845771
08:25:31   stop bce loss: 0.270361
08:25:31   total loss: 2.66536
08:25:32   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
08:25:33   edge acc: 0.896947
08:25:33   edge loss: 0.251022
08:25:33   node1 acc: 0.450382
08:25:33   node1 loss: 1.28832
08:25:33   node2 acc: 0.7162

###**Reinforcement Tuning**

In [None]:
model = models.RGCN(input_dim=zinc_dataset.node_feature_dim,
                    num_relation=zinc_dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

task = tasks.GCPNGeneration(model, zinc_dataset.atom_types,
                            max_edge_unroll=12, max_node=38,
                            task=('qed','plogp'), criterion=('ppo', 'nll'),
                            reward_temperature=1,
                            agent_update_interval=3, gamma=0.9)

optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, zinc_dataset, None, None, optimizer,
                     gpus=(0,), batch_size=32, log_interval=10)

solver.load('/content/drive/MyDrive/bio_project/zinc/gcpn_zinc250k_5epoch.pkl',
            load_optimizer=False)

solver.train(num_epoch=10)
solver.save('/content/drive/MyDrive/bio_project/zinc/gcpn_zinc250k_5epoch_finetune.pkl')



13:49:10   Preprocess training set
13:49:11   Load checkpoint from /content/drive/MyDrive/bio_project/zinc/gcpn_zinc250k_5epoch.pkl
13:49:11   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:49:11   Epoch 0 begin




13:49:13   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:49:13   PPO objective: 2.02883
13:49:13   Penalized logP: -5.17918
13:49:13   Penalized logP (max): 0.895549
13:49:13   QED: 0.657484
13:49:13   QED (max): 0.852045
13:49:13   edge acc: 0.896552
13:49:13   edge loss: 0.25208
13:49:13   node1 acc: 0.496807
13:49:13   node1 loss: 1.31632
13:49:13   node2 acc: 0.744572
13:49:13   node2 loss: 0.831459
13:49:13   stop acc: 0.83681
13:49:13   stop bce loss: 0.23149
13:49:13   total loss: 2.63135
13:49:28   1 / 28 molecules are invalid even after 20 resampling
13:49:33   1 / 13 molecules are invalid even after 20 resampling
13:49:34   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:49:34   PPO objective: 1.44125
13:49:34   Penalized logP: -5.56221
13:49:34   Penalized logP (max): 2.26699
13:49:34   QED: 0.639046
13:49:34   QED (max): 0.84194
13:49:34   edge acc: 0.897135
13:49:34   edge loss: 0.250978
13:49:34   node1 acc: 0.485677
13:49:34   node1 loss: 1.29907
13:49:34   node2 acc: 0.735677
13:49:34   node2 lo

RuntimeError: ignored

###**Generate**

In [None]:
model = models.RGCN(input_dim=zinc_dataset.node_feature_dim,
                    num_relation=zinc_dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)

task = tasks.GCPNGeneration(model, zinc_dataset.atom_types, max_edge_unroll=12,
                            max_node=38, criterion="nll")


optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, zinc_dataset, None, None, optimizer,
                     gpus=(0,), batch_size=32, log_interval=10)

solver.save("/content/drive/MyDrive/bio_project/zinc/gcpn_zinc250k_5epoch.pkl")
results = task.generate(num_sample=100, max_resample=5)
all_smiles = results.to_smiles()

15:51:52   Preprocess training set
15:52:01   Save checkpoint to /content/drive/MyDrive/bio_project/zinc/gcpn_zinc250k_5epoch.pkl




15:52:21   1 / 100 molecules are invalid even after 5 resampling
15:52:22   4 / 99 molecules are invalid even after 5 resampling
15:52:22   8 / 95 molecules are invalid even after 5 resampling
15:52:22   8 / 87 molecules are invalid even after 5 resampling
15:52:22   7 / 79 molecules are invalid even after 5 resampling
15:52:22   3 / 55 molecules are invalid even after 5 resampling
15:52:23   2 / 23 molecules are invalid even after 5 resampling
15:52:23   1 / 8 molecules are invalid even after 5 resampling


###**Analyze the resule**

In [None]:
! wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh
! chmod +x Miniconda3-py37_4.8.2-Linux-x86_64.sh
! bash ./Miniconda3-py37_4.8.2-Linux-x86_64.sh -b -f -p /usr/local
! conda install -c rdkit rdkit -y
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')

--2021-11-27 15:52:45--  https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.131.3, 104.16.130.3, 2606:4700::6810:8203, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.131.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 85055499 (81M) [application/x-sh]
Saving to: ‘Miniconda3-py37_4.8.2-Linux-x86_64.sh’


2021-11-27 15:52:46 (116 MB/s) - ‘Miniconda3-py37_4.8.2-Linux-x86_64.sh’ saved [85055499/85055499]

PREFIX=/usr/local
Unpacking payload ...
Collecting package metadata (current_repodata.json): - \ | done
Solving environment: - done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - _libgcc_mutex==0.1=main
    - asn1crypto==1.3.0=py37_0
    - ca-certificates==2020.1.1=0
    - certifi==2019.11.28=py37_0
    - cffi==1.14.0=py37h2e261b9_0
    - chardet==3.0.4=py37_1003
    - conda-package-handling==1.6.0=py37h7b6447c_0
    

In [None]:
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

In [None]:
def calculate_logp_qed(smiles):
  logP = []
  qed = []
  for smile in smiles : 
    mol = Chem.MolFromSmiles(smile)
    if mol != None : 
      logP.append(Descriptors.MolLogP(mol))
      qed.append(Chem.QED.weights_max(mol))
  return (logP,qed)

In [None]:
logP,qed = calculate_logp_qed(all_smiles)
data = {'smiles':all_smiles, 'logP':logP, 'qed': qed}
df2 = pd.DataFrame(data=data)

In [None]:
df2.head()

Unnamed: 0,smiles,logP,qed
0,CC=CC(C)C,2.2185,0.451964
1,CCC=C(C)C,2.3626,0.452347
2,C#CC=C=C=C,1.1158,0.291686
3,CC=C(C)CC,2.3626,0.452347
4,C#CC(C)=CC,1.5858,0.410933


In [None]:
df2.describe()

Unnamed: 0,logP,qed
count,66.0,66.0
mean,1.656964,0.405687
std,0.556307,0.060442
min,0.3585,0.291605
25%,1.2942,0.355993
50%,1.66415,0.410285
75%,2.13585,0.449146
max,2.612,0.526301


###**Export**

In [None]:
df2.to_csv(path+'zinc_output.csv')