# Import and Setup

In [1]:
# Clone the GraphEGFR repository into a scratch directory, then move its
# contents into the current working directory.
# NOTE(review): -l/-s are git's local-clone options; they look like no-ops for
# an https URL (the recorded output shows a normal network fetch) — confirm.
!git clone -l -s https://github.com/manbaritone/GraphEGFR .proj.temp/
# Remove the cloned run.ipynb before the move (presumably so it cannot
# overwrite this notebook — TODO confirm).
%rm .proj.temp/run.ipynb
%mv .proj.temp/* ./
# Delete the scratch directory and anything the move above left in it.
%rm -r .proj.temp/

Cloning into '.proj.temp'...
remote: Enumerating objects: 178, done.[K
remote: Counting objects: 100% (178/178), done.[K
remote: Compressing objects: 100% (112/112), done.[K
remote: Total 178 (delta 100), reused 123 (delta 60), pack-reused 0[K
Receiving objects: 100% (178/178), 541.91 KiB | 2.24 MiB/s, done.
Resolving deltas: 100% (100/100), done.


In [2]:
# Install condacolab and use it to set up conda inside this runtime.
!pip install -q condacolab
import condacolab
import sys
# The recorded output shows "Restarting kernel..." after this call, and the
# execution counter starts again at 1 in the following cell.
condacolab.install()
# Print the current PYTHONPATH, then override it for the session.
!echo $PYTHONPATH
# NOTE(review): the new value is a path ending in python3.10 under
# /usr/local/bin rather than a package directory — confirm this is intended.
%env PYTHONPATH=/usr/local/bin/python3.10

⏬ Downloading https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:13
🔁 Restarting kernel...
/env/python
env: PYTHONPATH=/usr/local/bin/python3.10


In [1]:
# Re-import after the kernel restart and verify the condacolab installation;
# the recorded output reports "Everything looks OK!" on success.
import condacolab
condacolab.check()

✨🍰✨ Everything looks OK!


In [None]:
# Install Python dependencies
!pip install scikit-learn==1.2.2
!pip install imblearn
!conda install conda-forge::openbabel
!python -m pip install rdkit
!python -m pip install deepchem==2.5.0
!python -m pip install JPype1
!pip install torch==2.0.0
!pip install torch_geometric
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!python -m pip install torchmetrics
!pip install dgl==1.1.3 -f https://data.dgl.ai/wheels/cu118/repo.html
!pip install dgllife


In [None]:
# Download large required project files for the test
!wget https://zenodo.org/records/11118070/files/GraphEGFR.tar.gz
!tar -xvf /content/GraphEGFR.tar.gz --warning=no-unknown-keyword -C .
!rm GraphEGFR/*.ipynb
!mv GraphEGFR/* ./
!rm -r GraphEGFR

In [11]:
import pandas as pd
import os
import torch
from torch_geometric.loader import DataLoader

from graphegfr.models import GraphEGFR
from graphegfr.configs import Configs
from graphegfr.fingerprint import Fingerprint
from graphegfr.featurizer import generate_npdata, clean_smiles
from graphegfr.dataset import load_dataset

In [12]:
# Individual prediction tasks, split into the two groups that the multi-task
# ("MTL_*") configurations are assembled from.
_HER_TARGETS = ["HER1", "HER2", "HER4"]
_MUTANT_TARGETS = ["T790M_L858R", "L858R", "delE746_A750", "T790M"]

# Maps a configuration name to the ordered list of tasks it covers.
# Multi-task configurations come first, followed by one single-task entry per
# individual target (each mapping to a one-element list holding its own name).
target_dict = {
    "MTL_HER124": list(_HER_TARGETS),
    "MTL_ALL_WT_MT": _HER_TARGETS + _MUTANT_TARGETS,
    "MTL_HER1_ALL_MT": ["HER1"] + _MUTANT_TARGETS,
    "MTL_ALL_MT": list(_MUTANT_TARGETS),
}
target_dict.update(
    {task_name: [task_name] for task_name in _HER_TARGETS + _MUTANT_TARGETS}
)

# Main Section

In [21]:
# Choose the target to predict; the available names are the keys of `target_dict`.
target = "T790M_L858R"
print_architecture = False  # set to True to print the model once it is loaded

# Directory where the generated fingerprint files are saved.
datapath = ".temp"
smiles_path = f"resources/LigEGFR/data/{target}.csv"

# Any iterable of SMILES strings can be used here; this notebook reads the
# "SMILES_NS" column of the target's CSV as a pandas Series.
smiles_raw = pd.read_csv(smiles_path)["SMILES_NS"]

# Clean the input and keep only the first five entries as a quick sample
# (drop the `.iloc[:5]` slice to run on every compound).
smiles = clean_smiles(smiles_raw).iloc[:5]
for smiles_string in smiles.tolist():
    print(smiles_string)

Number of defect: 0
C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1
C#Cc1nn([C@@H]2CCCN(C(=O)C=C)C2)c2ncnc(N)c12
C=C=CC(=O)Nc1cc(Nc2ncc(C)c(-c3cn(C)c4ccccc34)n2)c(OC)cc1N(C)CCN(C)C
C=C=CC(=O)Nc1cc(Nc2ncc(OC)c(-c3cn(C)c4ccccc34)n2)c(OC)cc1N(C)CCN(C)C
C=C=CC(=O)Nc1cc(Nc2nccc(-c3c[nH]c4ccccc34)n2)c(OC)cc1N(C)CCN(C)C


In [None]:
# Generate the fingerprint files and the graph inputs for the selected SMILES.
print("Generating fingerprint...")
# The return value is discarded; the fingerprint CSVs that the next cell reads
# from `datapath` are presumably written by this step — TODO confirm.
Fingerprint(smiles, datapath)
# Of the four results, the cells shown in this notebook only use `adj` and `edge`.
adj, feature, graph, edge = generate_npdata(smiles, datapath)

print("Done")
# Optional sanity checks — uncomment to inspect what was generated:
# print(adj.shape)
# print(feature.shape)
# print(graph.shape)
# print(len(edge))

In [15]:
# Read back the two fingerprint tables produced by the fingerprint-generation cell.
fpc = pd.read_csv(f'{datapath}/fingerprint-nonhash.csv').to_numpy()
fpf = pd.read_csv(f'{datapath}/fingerprint-hash.csv').to_numpy()

# Width of each table (size of its last axis); both are passed to the model constructor.
fingfeaf = fpf.shape[-1]
fingfeac = fpc.shape[-1]

# Placeholder labels: one None per entry of `edge`.
Label = [None] * len(edge)

# One float tensor per table row; rows of the non-hashed table additionally
# get a trailing axis of size 1.
fpfs = [torch.FloatTensor(row) for row in fpf]
fpcs = [torch.FloatTensor(row).unsqueeze(1) for row in fpc]

In [16]:
# Build the dataset from `adj`, the SMILES, the placeholder labels and the two
# lists of fingerprint tensors; `smiles_list` is paired with the loader's
# batches in the prediction loop.
dataset, smiles_list = load_dataset(adj, smiles, Label, fpfs, fpcs)

In [17]:
# batch_size=1 with shuffle=False gives one dataset item per batch, in dataset
# order; the prediction loop relies on this when it zips the loader with
# `smiles_list`.
test_loader = DataLoader(dataset,batch_size=1,shuffle=False)

In [18]:
# Show every name that can be assigned to `target`, one per line.
for target_name in target_dict.keys():
    print(target_name)

MTL_HER124
MTL_ALL_WT_MT
MTL_HER1_ALL_MT
MTL_ALL_MT
HER1
HER2
HER4
T790M_L858R
L858R
delE746_A750
T790M


In [19]:
# Load the sample configuration for the chosen target and read the
# hyperparameters needed to rebuild the network.
configs = Configs.parse(f"configs/sample/{target}-conf.json")
hpconfig = configs['hyperparam']
num_atom_features = hpconfig["num_atom_features"]
edge_dim = hpconfig["edge_dim"]
fingerprint_dim = hpconfig["fingerprint_dim"]
num_layers = hpconfig["num_layers"]
num_timesteps = hpconfig["num_timesteps"]
dropout = 0 # not used in eval mode regardless

# `fingfeaf` / `fingfeac` are the fingerprint widths computed when the
# fingerprint tables were loaded.
model = GraphEGFR(num_atom_features,edge_dim, fingerprint_dim,
                  num_layers, num_timesteps, dropout, fingfeaf,
                  fingfeac, configs)
# map_location="cpu": the model is run on CPU below, so load the weights onto
# the CPU as well. Without it, a checkpoint saved from a GPU cannot be
# deserialized on a CPU-only runtime.
state_dict = torch.load(f"./state_dict/{target}.pt", map_location="cpu")
model.load_state_dict(state_dict)
# Inference only: keep the model on CPU and switch it to evaluation mode.
model.cpu().eval()
if print_architecture:
    print("== Model Architecture ==")
    print("Target:", target)
    print("Model:\n",model)




In [20]:
# Run the model on every molecule and collect one row of predictions per SMILES.
task_names = target_dict[configs["target"]]
if "MTL" not in target:
    # Single-task run: report only the output column that belongs to `target`.
    index_ans = task_names.index(target)
    actual_target = task_names[index_ans]
    output_columns = [(index_ans, actual_target)]
else:
    # Multi-task run: report one column per task, in the order listed in target_dict.
    output_columns = list(enumerate(task_names))

# "smiles" goes first so that it is the leading column of the result table.
records = {"smiles": []}
for _, task_name in output_columns:
    records[task_name] = []

# Inference only, so disable gradient tracking while calling the model.
with torch.no_grad():
    # test_loader yields one unshuffled item per batch, so its batches line up
    # one-to-one with smiles_list.
    for data, smi in zip(test_loader, smiles_list):
        prediction = model(data)
        for column, task_name in output_columns:
            records[task_name].append(prediction[0, column].item())
        records["smiles"].append(smi)
df_records = pd.DataFrame(records)
df_records

Unnamed: 0,smiles,T790M_L858R
0,C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1,5.054157
1,C#Cc1nn([C@@H]2CCCN(C(=O)C=C)C2)c2ncnc(N)c12,6.057258
2,C=C=CC(=O)Nc1cc(Nc2ncc(C)c(-c3cn(C)c4ccccc34)n...,8.891642
3,C=C=CC(=O)Nc1cc(Nc2ncc(OC)c(-c3cn(C)c4ccccc34)...,8.990313
4,C=C=CC(=O)Nc1cc(Nc2nccc(-c3c[nH]c4ccccc34)n2)c...,8.798269
