In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
from tdc.multi_pred import DTI

In [2]:
from utils import TestbedDataset
from create_data import seq_cat, smile_to_graph
from interface import GraphDTA, Trainer

In [3]:
def get_smile_graph(list_of_smiles):
    """Featurize each distinct SMILES string exactly once.

    Args:
        list_of_smiles: iterable of SMILES strings; duplicates are allowed and
            are collapsed before featurization.

    Returns:
        dict mapping every unique SMILES string to ``smile_to_graph(smiles)``.
    """
    graphs = {}
    # Deduplicate first so smile_to_graph is called once per distinct molecule.
    for smiles in set(list_of_smiles):
        graphs[smiles] = smile_to_graph(smiles)
    return graphs
       

def create_tdc_dataset(df, fold, smile_graph, dataset_name='tdc', root='data', max_seqlen=1000):
    """Build a GraphDTA ``TestbedDataset`` from one TDC DTI split.

    Args:
        df: split DataFrame with 'Drug' (SMILES), 'Target' (protein sequence)
            and 'Y' (affinity label) columns.
        fold: split label such as 'train' / 'valid' / 'test'; it becomes part of
            the processed-file name ``f'{dataset_name}_{fold}'``.
        smile_graph: mapping SMILES -> graph (e.g. from get_smile_graph);
            presumably must contain every SMILES in ``df['Drug']`` — verify in
            utils.TestbedDataset.
        dataset_name: prefix of the processed-file name. Use a distinct value per
            data source / split seed. When a processed file with this name
            already exists under ``root``, TestbedDataset reports
            "Pre-processed data found ... loading" and presumably uses the cached
            tensors instead of ``df`` (TODO confirm in utils.TestbedDataset).
        root: dataset root directory handed to TestbedDataset.
        max_seqlen: sequence length passed to ``seq_cat`` for every target.

    Returns:
        The TestbedDataset for this split.
    """
    # Residue vocabulary with 1-based indices. Note that 'J' is absent; keep the
    # alphabet unchanged so existing processed files and checkpoints stay valid.
    aaseq_dict = {ch: idx for idx, ch in enumerate("ABCDEFGHIKLMNOPQRSTUVWXYZ", start=1)}

    xd = df['Drug'].values
    xt = np.asarray([seq_cat(seq, max_seqlen, aaseq_dict) for seq in df['Target'].values])
    y = df['Y'].values

    return TestbedDataset(root=root, dataset=f'{dataset_name}_{fold}',
                          xd=xd, xt=xt, y=y, smile_graph=smile_graph, use_tqdm=True)

In [4]:
# Short tag for this data source (not referenced by the cells shown here).
dataset_name = 'tdc'

# TDC drug-target interaction benchmark to load; TDC reuses a local copy when present.
TDC_BENCHMARK = 'BindingDB_Kd'
data = DTI(name=TDC_BENCHMARK)

Found local copy...
Loading...
Done!


In [5]:
# Move affinities to log space, then collapse duplicate drug-target measurements,
# keeping the maximum (log-scale) affinity per pair.
# convert_to_log mutates `data` in place, so re-running this cell would log-transform
# values that are already in log space. Guard on the loader's `log_flag`
# (assumption: TDC sets it after conversion — verify against the installed tdc
# version). If the attribute is absent, the conversion runs exactly as before.
if not getattr(data, 'log_flag', False):
    data.convert_to_log(form = 'binding')
df_whole = data.harmonize_affinities(mode = 'max_affinity')

To log space...
The scale is converted to log scale, so we will take the maximum!
The original data has been updated!


In [6]:
# Featurize every SMILES in the full harmonized table once; all splits share this lookup.
smile_graph = get_smile_graph(df_whole['Drug'].to_numpy())

In [7]:
# Cold-drug split: partitions are formed by drug, so test metrics reflect
# performance on compounds never seen during training.
SPLIT_SEED = 2023
frames = data.get_split(method='cold_drug', seed=SPLIT_SEED, frac=[0.8, 0.1, 0.1])
df_tr = frames['train']
df_va = frames['valid']
df_te = frames['test']

# Invariant check: a cold-drug split must not share any drug between partitions.
drugs_tr, drugs_va, drugs_te = (set(split['Drug']) for split in (df_tr, df_va, df_te))
assert drugs_tr.isdisjoint(drugs_va), "cold_drug split leaked drugs between train and valid"
assert drugs_tr.isdisjoint(drugs_te), "cold_drug split leaked drugs between train and test"
assert drugs_va.isdisjoint(drugs_te), "cold_drug split leaked drugs between valid and test"

In [8]:
# Inspect the held-out test split (pandas truncates the rendered table).
df_te

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,311.0,O=C(O)CC(O)(CC(=O)O)C(=O)O,P15474,MPRSLANAPIMILNGPNLNLLGQRQPEIYGSDTLADVEALCVKAAA...,5.136671
1,311.0,O=C(O)CC(O)(CC(=O)O)C(=O)O,Q48255,MKILVIQGPNLNMLGHRDPRLYGMVTLDQIHEIMQTFVKQGNLDVE...,5.602043
2,1110.0,O=C(O)CCC(=O)O,Q9GZT9,MANDSGGPGGPSPSERDRQYCELCGKMENLLRCSRCRSSFYCCKEH...,4.000000
3,1110.0,O=C(O)CCC(=O)O,Q9Z429,MQQFTIRTRLLMLVGAMFIGFITIELMGFSALQRGVASLNTVYLDR...,4.086186
4,1688.0,CN1C(=O)CN=C(c2ccc(Cl)cc2)c2cc(Cl)ccc21,P16257,MSQSWVPAVGLTLVPSLGGFMGAYFVRGEGLRWYASLQKPSWHPPR...,7.235824
...,...,...,...,...,...
3388,138805877.0,COc1nc2ccc([C@@](O)(c3ccc(C#N)cc3)c3cncn3C)cc2...,P51449,MDRAPQRQHRASRELLAAKKTHTSQIEVIPCKICGDKSSGIHYGVI...,7.186419
3389,138805881.0,COc1nc2ccc([C@@](O)(c3ccc(C#N)cc3)c3cncn3C)cc2...,P51449,MDRAPQRQHRASRELLAAKKTHTSQIEVIPCKICGDKSSGIHYGVI...,7.675718
3390,138805882.0,COc1nc2ccc([C@](O)(c3ccc(Cl)cc3)c3cncn3C)cc2c(...,P51449,MDRAPQRQHRASRELLAAKKTHTSQIEVIPCKICGDKSSGIHYGVI...,5.602043
3391,138805884.0,COc1nc2ccc([C@](O)(c3ccnc(C(F)(F)F)c3)c3cncn3C...,P51449,MDRAPQRQHRASRELLAAKKTHTSQIEVIPCKICGDKSSGIHYGVI...,5.886023


In [9]:
# Encode the training split. The output shows an existing data/processed/tdc_train.pt
# being loaded; remove it after changing the source data or split seed.
train_data = create_tdc_dataset(df=df_tr, fold='train', smile_graph=smile_graph)

Pre-processed data found: data/processed/tdc_train.pt, loading ...


In [10]:
# Encode the test split. The output shows an existing data/processed/tdc_test.pt
# being loaded; remove it after changing the source data or split seed.
test_data = create_tdc_dataset(df=df_te, fold='test', smile_graph=smile_graph)

Pre-processed data found: data/processed/tdc_test.pt, loading ...


In [11]:
# Encode the validation split. The output shows an existing data/processed/tdc_valid.pt
# being loaded; remove it after changing the source data or split seed.
valid_data = create_tdc_dataset(df=df_va, fold='valid', smile_graph=smile_graph)

Pre-processed data found: data/processed/tdc_valid.pt, loading ...


In [12]:
# Use the first GPU when torch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device("cpu")

In [13]:
# Candidate architecture names; index 0 selects 'GINConvNet'.
modeling = ['GINConvNet', 'GATNet', 'GAT_GCN', 'GCNNet'][0]

# Seed torch and numpy before building the model so weight initialization (and any
# torch/numpy RNG use inside the Trainer) is repeatable across runs. GPU kernels may
# still be nondeterministic, so bit-exact reproducibility is only expected on CPU.
TRAIN_SEED = 2023
torch.manual_seed(TRAIN_SEED)
np.random.seed(TRAIN_SEED)

model = GraphDTA(modeling, device=device)

In [14]:
# Wrap the underlying network (model.model) and its device in the project Trainer.
trainer = Trainer(model.model, model.device)

In [None]:
# Train the model. Positional order is (train, test, valid): the log reports
# len(test_data)=3393, matching df_te's 3393 rows.
# Checkpoints presumably go to ./ckpt/<modeling>_tdc.pt (inferred from the arguments).
# NOTE(review): the log reports improvements as best_test_mse / best_test_ci — confirm
# the checkpoint is selected on the validation set, not the test set.
trainer.train(train_data, test_data, valid_data,
              ckpt_dir='./ckpt', ckpt_filename=f'{modeling}_tdc.pt')

[INFO] len(train_data):36265
[INFO] len(valid_data):2578
[INFO] len(test_data):3393
Training on 36265 samples...
[INFO] train epoch: 000001 (00%)	Loss: 30.405361
[INFO] train epoch: 000001 (28%)	Loss: 2.205092
[INFO] train epoch: 000001 (56%)	Loss: 1.657604
[INFO] train epoch: 000001 (84%)	Loss: 1.395095
Make prediction for 2578 samples...
Make prediction for 3393 samples...
[INFO] rmse improved at epoch 1; best_test_mse=1.606; best_test_ci=0.684
Training on 36265 samples...
[INFO] train epoch: 000002 (00%)	Loss: 1.257134
[INFO] train epoch: 000002 (28%)	Loss: 1.218627
[INFO] train epoch: 000002 (56%)	Loss: 1.189084
[INFO] train epoch: 000002 (84%)	Loss: 1.109169
Make prediction for 2578 samples...
Make prediction for 3393 samples...
[INFO] rmse improved at epoch 2; best_test_mse=1.216; best_test_ci=0.738
Training on 36265 samples...
[INFO] train epoch: 000003 (00%)	Loss: 1.140674
[INFO] train epoch: 000003 (28%)	Loss: 1.062391
[INFO] train epoch: 000003 (56%)	Loss: 0.989916
[INFO] tra