In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from deepdtagen.demo.demo_utils import *
from deepdtagen.demo.model_aff import DeepDTAGen

# 1. Environments

In [2]:
# Dataset identifier; it is embedded in the checkpoint and tokenizer file names built below.
dataset_name = 'bindingdb'

## Setup device

In [3]:
# Inference is pinned to the CPU. To use a GPU when one is available,
# swap the active assignment for the commented-out line below.
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

## Paths

In [4]:
# Relative locations of the model checkpoint and the pickled tokenizer.
# Both file names are derived from `dataset_name`.
filepath_model = os.path.join('models', f'deepdtagen_model_{dataset_name}.pth')
filepath_tokenizer = os.path.join('data', f'{dataset_name}_tokenizer.pkl')

## Load Tokenizer

In [5]:
# Restore the tokenizer object from its pickle file.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only open tokenizer files that come from a trusted source.
with open(filepath_tokenizer, 'rb') as f:
    tokenizer = pickle.load(f)

In [6]:
# Sanity check: show the size reported by len() for the loaded tokenizer.
print(len(tokenizer))

107


## Load Model

In [7]:
# Instantiate the network; the tokenizer is handed to the constructor.
model = DeepDTAGen(tokenizer)

In [8]:
# Load the trained weights; map_location remaps the stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(filepath_model, map_location=device))

<All keys matched successfully>

In [9]:
# Move the model to the selected device; binding the return value to `_`
# keeps it from being echoed as the cell output.
_ = model.to(device)

In [10]:
# Put the model in evaluation mode before running inference; `_` suppresses the echoed return value.
_ = model.eval()

## Test Data

In [11]:
seqs_smi = [
    "CC1=CC[C@@H]2[C@@H](C1)C3=C(C=C(C=C3OC2(C)C)C(C)(C)CCCCCCBr)O",
    "CCCCCCC(C)(C)C1=CC(=C2[C@@H]3C[C@@H](CC[C@H]3[C@](OC2=C1)(C)/C=C/CO)CO)O",
    "CCCCCN1C=C(C2=CC=CC=C21)C(=O)C3C(C3(C)C)(C)C",
    "CC1=C(C2=C3N1[C@@H](COC3=CC=C2)CN4CCOCC4)C(=O)C5=CC=CC6=CC=CC=C65",
    "CCCCCC1=CC(=C2[C@@H]3C=C(CC[C@H]3C(OC2=C1)(C)C)C)O",
    "C[C@@H]([C@@H](CC1=CC=C(C=C1)Cl)C2=CC=CC(=C2)C#N)NC(=O)C(C)(C)OC3=NC=C(C=C3)C(F)(F)F",
    "CC1=C(N(N=C1C(=O)NN2CCCCC2)C3=C(C=C(C=C3)Cl)Cl)C4=CC=C(C=C4)I",
    "CCC1=C(N(N=C1C(=O)NN2CCCCC2)C3=C(C=C(C=C3)Cl)Cl)C4=CC=C(C=C4)Br",
    "CC1=C(C2=C(N1CCN3CCOCC3)C=C(C=C2)I)C(=O)C4=CC=C(C=C4)OC",
    "CC1=CC=C(C=C1)CN2C(=CC(=N2)C(=O)N[C@H]3[C@]4(CC[C@H](C4)C3(C)C)C)C5=CC(=C(C=C5)Cl)C",
    "CC1=C(N(N=C1C(=O)NC23CC4CC(C2)CC(C4)C3)CCCCCO)C5=CC=CC=C5",
    "CCCCCC1=CC(=C(C(=C1)O)[C@@H]2C=C(CC[C@H]2C(=C)C)C)O",
]

In [12]:
seqs_prot = [
    "MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQKFPLTSFRGSPFQEKMTAGDNPQLVPADQVNITEFYNKSLSSFKENEENIQCGENFMDIECFMVLNPSQQLAIAVLSLTLGTFTVLENLLVLCVILHSRSLRCRPSYHFIGSLAVADLLGSVIFVYSFIDFHVFHRKDSRNVFLFKLGGVTASFTASVGSLFLTAIDRYISIHRPLAYKRIVTRPKAVVAFCLMWTIAIVIAVLPLLGWNCEKLQSVCSDIFPHIDETYLMFWIGVTSVLLLFIVYAYMYILWKAHSHAVRMIQRGTQKSIIIHTSEDGKVQVTRPDQARMDIRLAKTLVLILVVLIICWGPLLAIMVYDVFGKMNKLIKTVFAFCSMLCLLNSTVNPIIYALRSKDLRHAFRSMFPSCEGTAQPLDNSMGDSDCLHKHANNAASVHRAAESCIKSTVKIAKVTMSVSTDTSAEAL",
    "MEECWVTEIANGSKDGLDSNPMKDYMILSGPQKTAVAVLCTLLGLLSALENVAVLYLILSSHQLRRKPSYLFIGSLAGADFLASVVFACSFVNFHVFHGVDSKAVFLLKIGSVTMTFTASVGSLLLTAIDRYLCLRYPPSYKALLTRGRALVTLGIMWVLSALVSYLPLMGWTCCPRPCSELFPLIPNDYLLSWLLFIAFLFSGIIYTYGHVLWKAHQHVASLSGHQDRQVPGMARMRLDVRLAKTLGLVLAVLLICWFPVLALMAHSLATTLSDQVKKAFAFCSMLCLINSMVNPVIYALRSGEIRSSAHHCLAHWKKCVRGLGSEAKEEAPRSSVTETEADGKITPWPDSRDLDLSDC",    
]

In [13]:
# Pair every SMILES with every protein sequence: the protein is the outer
# loop, so all drugs for the first target come before those for the second.
df_inputs = pd.DataFrame(
    [(smi, prot) for prot in seqs_prot for smi in seqs_smi],
    columns=['SMILES', 'TARGET'],
)

In [14]:
# Display the drug-target input table.
df_inputs

Unnamed: 0,SMILES,TARGET
0,CC1=CC[C@@H]2[C@@H](C1)C3=C(C=C(C=C3OC2(C)C)C(...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
1,CCCCCCC(C)(C)C1=CC(=C2[C@@H]3C[C@@H](CC[C@H]3[...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
2,CCCCCN1C=C(C2=CC=CC=C21)C(=O)C3C(C3(C)C)(C)C,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
3,CC1=C(C2=C3N1[C@@H](COC3=CC=C2)CN4CCOCC4)C(=O)...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
4,CCCCCC1=CC(=C2[C@@H]3C=C(CC[C@H]3C(OC2=C1)(C)C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
5,C[C@@H]([C@@H](CC1=CC=C(C=C1)Cl)C2=CC=CC(=C2)C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
6,CC1=C(N(N=C1C(=O)NN2CCCCC2)C3=C(C=C(C=C3)Cl)Cl...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
7,CCC1=C(N(N=C1C(=O)NN2CCCCC2)C3=C(C=C(C=C3)Cl)C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
8,CC1=C(C2=C(N1CCN3CCOCC3)C=C(C=C2)I)C(=O)C4=CC=...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...
9,CC1=CC=C(C=C1)CN2C(=CC(=N2)C(=O)N[C@H]3[C@]4(C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...


# 2. Binding Affinity Prediction

## Data Loader

In [15]:
def create_dataset(seqs_smi, seqs_prot, root='data', name='tmp'):
    """Build a ``TestbedDataset`` of drug-target pairs for affinity prediction.

    Parameters
    ----------
    seqs_smi : iterable of str
        SMILES strings, one per pair. Repeated strings are converted to a
        graph only once.
    seqs_prot : iterable of str
        Protein sequences; each one is encoded with ``seq_cat``.
        Presumably entry ``i`` pairs with ``seqs_smi[i]`` (the notebook
        passes two columns of the same frame) — confirm against
        ``TestbedDataset``.
    root : str, optional
        Forwarded to ``TestbedDataset`` as ``root``. Defaults to ``'data'``.
    name : str, optional
        Forwarded to ``TestbedDataset`` as ``dataset``. Defaults to ``'tmp'``.
        NOTE(review): if ``TestbedDataset`` caches processed files keyed on
        ``root``/``dataset``, pass distinct names for distinct inputs.

    Returns
    -------
    TestbedDataset
        Dataset built from the SMILES array, the encoded proteins and the
        SMILES-to-graph lookup.
    """
    # One graph per distinct SMILES; dict.fromkeys drops repeats while
    # preserving first-seen order, so each molecule is converted exactly once.
    smile_graph = {smi: smile_to_graph(smi) for smi in dict.fromkeys(seqs_smi)}

    XD = np.asarray(seqs_smi)
    XT = np.asarray([seq_cat(aa) for aa in seqs_prot])

    return TestbedDataset(
        root=root,
        dataset=name,
        xd=XD,
        xt=XT,
        smile_graph=smile_graph
    )

In [16]:
# Build the affinity-prediction dataset from the two columns of the input table.
test_data = create_dataset(df_inputs['SMILES'], df_inputs['TARGET'])

Preparing data in Pytorch Format: 1/24
Preparing data in Pytorch Format: 2/24
Preparing data in Pytorch Format: 3/24
Preparing data in Pytorch Format: 4/24
Preparing data in Pytorch Format: 5/24
Preparing data in Pytorch Format: 6/24
Preparing data in Pytorch Format: 7/24
Preparing data in Pytorch Format: 8/24
Preparing data in Pytorch Format: 9/24
Preparing data in Pytorch Format: 10/24
Preparing data in Pytorch Format: 11/24
Preparing data in Pytorch Format: 12/24
Preparing data in Pytorch Format: 13/24
Preparing data in Pytorch Format: 14/24
Preparing data in Pytorch Format: 15/24
Preparing data in Pytorch Format: 16/24
Preparing data in Pytorch Format: 17/24
Preparing data in Pytorch Format: 18/24
Preparing data in Pytorch Format: 19/24
Preparing data in Pytorch Format: 20/24
Preparing data in Pytorch Format: 21/24
Preparing data in Pytorch Format: 22/24
Preparing data in Pytorch Format: 23/24
Preparing data in Pytorch Format: 24/24


In [17]:
# Serve the pairs one at a time and in input order; the scoring loop below
# calls .item() on each model output, which needs a single-element result.
# NOTE(review): the fully-qualified torch.utils.data.DataLoader is kept as-is —
# the wildcard import from demo_utils may rebind the bare `DataLoader` name;
# confirm before shortening. `collate` is not defined in this notebook
# (presumably it comes from that same wildcard import).
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    collate_fn=collate
)

## Evaluate the model

In [18]:
# Score every drug-target pair; gradient tracking is switched off for inference.
with torch.no_grad():
    predictions = [
        model(batch.to(device)).item()
        for batch in tqdm(test_loader, desc='Testing')
    ]

Testing:   0%|          | 0/24 [00:00<?, ?it/s]

In [19]:
# New frame = the input pairs plus the predicted affinity column; df_inputs itself is left untouched.
df_res = df_inputs.assign(AFFINITY=predictions)

In [20]:
# Display the input pairs together with their predicted affinities.
df_res

Unnamed: 0,SMILES,TARGET,AFFINITY
0,CC1=CC[C@@H]2[C@@H](C1)C3=C(C=C(C=C3OC2(C)C)C(...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.694099
1,CCCCCCC(C)(C)C1=CC(=C2[C@@H]3C[C@@H](CC[C@H]3[...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.453372
2,CCCCCN1C=C(C2=CC=CC=C21)C(=O)C3C(C3(C)C)(C)C,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.217188
3,CC1=C(C2=C3N1[C@@H](COC3=CC=C2)CN4CCOCC4)C(=O)...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,7.315071
4,CCCCCC1=CC(=C2[C@@H]3C=C(CC[C@H]3C(OC2=C1)(C)C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.503256
5,C[C@@H]([C@@H](CC1=CC=C(C=C1)Cl)C2=CC=CC(=C2)C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.355811
6,CC1=C(N(N=C1C(=O)NN2CCCCC2)C3=C(C=C(C=C3)Cl)Cl...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.088691
7,CCC1=C(N(N=C1C(=O)NN2CCCCC2)C3=C(C=C(C=C3)Cl)C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.28292
8,CC1=C(C2=C(N1CCN3CCOCC3)C=C(C=C2)I)C(=O)C4=CC=...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,7.863091
9,CC1=CC=C(C=C1)CN2C(=CC(=N2)C(=O)N[C@H]3[C@]4(C...,MKSILDGLADTTFRTITTDLLYVGSNDIQYEDIKGDMASKLGYFPQ...,8.430277


# 3. Target-aware Drug Generation

In [21]:
def create_dataset2(seqs_smi, seqs_prot, affinity, root='data', name='tmp'):
    """Build a ``TestbedDataset2`` of drug-target pairs with affinity values.

    Parameters
    ----------
    seqs_smi : iterable of str
        SMILES strings, one per pair. Repeated strings are converted to a
        graph only once.
    seqs_prot : iterable of str
        Protein sequences; each one is encoded with ``seq_cat``.
        Presumably entry ``i`` pairs with ``seqs_smi[i]`` — confirm against
        ``TestbedDataset2``.
    affinity : array-like
        Affinity values, converted with ``np.asarray`` and passed to
        ``TestbedDataset2`` as ``y``.
    root : str, optional
        Forwarded to ``TestbedDataset2`` as ``root``. Defaults to ``'data'``.
    name : str, optional
        Forwarded to ``TestbedDataset2`` as ``dataset``. Defaults to ``'tmp'``.
        NOTE(review): if ``TestbedDataset2`` caches processed files keyed on
        ``root``/``dataset``, pass distinct names for distinct inputs.

    Returns
    -------
    TestbedDataset2
        Dataset built from the SMILES array, the encoded proteins, the
        affinity array and the SMILES-to-graph lookup.
    """
    # One graph per distinct SMILES; dict.fromkeys drops repeats while
    # preserving first-seen order, so each molecule is converted exactly once.
    smile_graph = {smi: smile_to_graph(smi) for smi in dict.fromkeys(seqs_smi)}

    XD = np.asarray(seqs_smi)
    XT = np.asarray([seq_cat(aa) for aa in seqs_prot])
    Y = np.asarray(affinity)

    return TestbedDataset2(
        root=root,
        dataset=name,
        xd=XD,
        xt=XT,
        y=Y,
        smile_graph=smile_graph
    )

In [22]:
# Build the generation dataset, feeding the affinities predicted in section 2 as the `affinity` input.
# NOTE(review): this rebinds `test_data`, replacing the affinity-prediction dataset from section 2.
test_data = create_dataset2(df_res['SMILES'], df_res['TARGET'], df_res['AFFINITY'])

Preparing data in Pytorch Format: 1/24
Preparing data in Pytorch Format: 2/24
Preparing data in Pytorch Format: 3/24
Preparing data in Pytorch Format: 4/24
Preparing data in Pytorch Format: 5/24
Preparing data in Pytorch Format: 6/24
Preparing data in Pytorch Format: 7/24
Preparing data in Pytorch Format: 8/24
Preparing data in Pytorch Format: 9/24
Preparing data in Pytorch Format: 10/24
Preparing data in Pytorch Format: 11/24
Preparing data in Pytorch Format: 12/24
Preparing data in Pytorch Format: 13/24
Preparing data in Pytorch Format: 14/24
Preparing data in Pytorch Format: 15/24
Preparing data in Pytorch Format: 16/24
Preparing data in Pytorch Format: 17/24
Preparing data in Pytorch Format: 18/24
Preparing data in Pytorch Format: 19/24
Preparing data in Pytorch Format: 20/24
Preparing data in Pytorch Format: 21/24
Preparing data in Pytorch Format: 22/24
Preparing data in Pytorch Format: 23/24
Preparing data in Pytorch Format: 24/24


In [23]:
# Serve the pairs one at a time and in input order; the generation loop below
# keeps only the first decoded string of each batch.
# NOTE(review): the fully-qualified torch.utils.data.DataLoader is kept as-is —
# the wildcard import from demo_utils may rebind the bare `DataLoader` name;
# confirm before shortening. This also rebinds `test_loader` from section 2.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    collate_fn=collate
)

## Generate molecules

In [24]:
# Generate one molecule per drug-target pair: decode the model output with the
# tokenizer and keep the first decoded string of each batch.
with torch.no_grad():
    generated = [
        tokenizer.get_text(model.generate(batch.to(device)))[0]
        for batch in tqdm(test_loader, desc='Testing')
    ]

Testing:   0%|          | 0/24 [00:00<?, ?it/s]

In [25]:
# Display the generated SMILES strings, one per input pair.
generated

['CC(=O)NCCCCCCNC(=O)[C@H](Cc1ccc(O)cc1)NC(=O)N1CCC(N2CCCC2)CC1)C(=O)O',
 'CCCCCCC(C)(C)c1ccc([C@@H]2C[C@H](O)CC[C@H]2CCCO)c(O)c1',
 'CC(C)CCC(C)(C)c1ccc(N2CCN(CCC(=O)NC(C)C)CC2)cc1',
 'O=C(CN1CCCC2(CCN(Cc3ccccc3)CC2)CC1)c1ccc(F)cc1',
 'CC(C)CCC(C)(C)c1ccc(N2CCN(CCC(=O)NC(C)C)CC2)cc1',
 'CC(C)(C)C(=O)CN1C(=O)N(CC(=O)Nc2cccc(-c3nc(=O)[nH][nH]3)c2)C(=O)N(C2CCCCCC2)c2ccccc21',
 'O=C(NCCCc1ccccc1)NCC1CCN(CCCc2ccccc2)CC1',
 'CC(C)CCN1CCC(NC(=O)N(C)Cc2ccccc2)C(=O)Nc2cccnc2)CC1',
 'COc1cccc(OC)c1CNC(=O)CSc1nc2ccccc2n1C',
 'Cc1cc(C)[nH]c2c1C(=O)NCCCCCNC(=O)CCSc1ccccc1',
 'CC1(C)CC(N)=C(CN)NC(=O)NCCCOc2ccc(S(=O)(=O)N3CCCCC3)cc2)C1',
 'CC(C)CCC(C)(C)c1ccc(N2CCN(CCC(=O)NC(C)C)CC2)cc1',
 'CC(C)CCN1CCC(NC(=O)N(CC(=O)Nc2ccccc2)c2ccc(Cl)cc2)CC1',
 'CC(C)CC(CC)C1CCN(CC(=O)N(C)Cc2ccc(N=C(C)C)cc2)C(=O)N(C2CCCCC2)c2ccccc21',
 'CCCCCCC(C)(C)c1ccc([C@@H]2C[C@H](O)CC[C@H]2CCCO)c(O)c1',
 'O=C(CN1CCCC(Cc2ccccc2)CC1)Nc1ccc(N2CCCC2)cc1',
 'CC(C)CC(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](CC(C)C)[C@@H](O)CC(=O)O)C(=O)