In [1]:
import os
# Work from the directory two levels above the notebook (presumably the
# repository root, so that `molexplain` and the relative paths below resolve —
# TODO confirm).
# The starting directory is remembered the first time this cell runs, so
# re-running the cell lands in the same place instead of climbing two more
# levels on every execution.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))

In [2]:
import numpy as np
import pandas as pd
import torch

from rdkit.Chem import MolFromSmiles, MolFromInchi, MolToInchi
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG

from molexplain.utils import MODELS_PATH, PROCESSED_DATA_PATH, DEVICE
from molexplain.vis import molecule_importance

Using backend: pytorch


In [3]:
from molexplain.net import MPNNPredictor

# Checkpoint holding the trained weights for the CYP3A4 model.
weights_path = os.path.join(MODELS_PATH, "CYP3A4_noHs.pt")

# Constructor arguments; these are assumed to match the configuration the
# checkpoint was trained with — TODO confirm against the training script.
mpnn_kwargs = dict(node_in_feats=46, edge_in_feats=10, global_feats=4, n_tasks=1)

model = MPNNPredictor(**mpnn_kwargs).to(DEVICE)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))

def smiles_to_inchi(sm):
    """Convert a SMILES string to InChI, keeping it only if it round-trips.

    Parameters
    ----------
    sm : str
        SMILES string to convert.

    Returns
    -------
    str or None
        The InChI string when the SMILES parses, converts to InChI, and that
        InChI can be parsed back into a molecule; ``None`` otherwise.
    """
    try:
        mol = MolFromSmiles(sm)
        if mol is None:
            # Unparsable SMILES: reject explicitly instead of relying on the
            # InChI conversion to raise on a missing molecule.
            return None
        inchi = MolToInchi(mol)
        if MolFromInchi(inchi) is None:
            return None
    except Exception:
        # Any conversion error marks the entry as invalid. `Exception` (not a
        # bare `except`) lets KeyboardInterrupt / SystemExit propagate, so
        # interrupting the cell no longer silently drops a molecule.
        return None
    return inchi


df = pd.read_csv('../cyp/CYP3A4.csv', header=0, sep=';')
smiles = df['SMILES'].to_numpy()

# Keep the InChI of every molecule that survives the round trip and remember
# the row positions of the ones that do not.
inchis = []
invalid_idx = []

for idx, sm in enumerate(smiles):
    inchi = smiles_to_inchi(sm)
    if inchi is None:
        invalid_idx.append(idx)
    else:
        inchis.append(inchi)

inchis = np.array(inchis)

# Binary labels as a column vector: 1.0 for 'Active', 0.0 for anything else.
values = (df['Class'] == 'Active').to_numpy(dtype=float)[:, np.newaxis]

# Drop the labels of the rejected rows so `values` stays aligned with `inchis`.
# The explicit integer dtype keeps the index array well-typed when no row was
# rejected.
value_idx = np.setdiff1d(np.arange(len(values)), np.array(invalid_idx, dtype=int))
values = values[value_idx, :]

assert len(inchis) == len(values), "inchis and values are out of sync"

In [4]:
# Inspect a single molecule and its label as a sanity check.
idx = 10
example_inchi = inchis[idx]
example_label = values[idx]

print(example_inchi, example_label, sep='\n')

InChI=1S/C28H33ClNOP/c1-3-11-23-21-28(23,20-4-2)27(22-16-18-24(29)19-17-22)30-32(31,25-12-7-5-8-13-25)26-14-9-6-10-15-26/h5-10,12-19,23,27H,3-4,11,20-21H2,1-2H3,(H,30,31)/t23-,27?,28-/m1/s1
[1.]


In [None]:
from tqdm import tqdm

# Directory that receives one SVG per molecule, named after its index in `inchis`.
IMG_DIR = 'imgs_cyp_noHs'
# Set to True to regenerate images that already exist on disk.
OVERWRITE_IMGS = False
os.makedirs(IMG_DIR, exist_ok=True)


for idx, inchi in enumerate(tqdm(inchis)):
    svg_path = os.path.join(IMG_DIR, f'{idx}.svg')

    # The full run takes hours; skipping finished files lets an interrupted
    # run resume where it stopped instead of starting over.
    if not OVERWRITE_IMGS and os.path.exists(svg_path):
        continue

    mol = MolFromInchi(inchi)
    # Only the rendered SVG is needed here; the other four return values are
    # discarded.
    svg, _, _, _, _ = molecule_importance(mol,
                                          model,
                                          task=0,
                                          vis_factor=5,
                                          addHs=False)

    # Write to a temporary file and rename it into place so that an
    # interruption mid-write never leaves a truncated SVG that the resume
    # check above would mistake for a finished one.
    tmp_path = svg_path + '.tmp'
    with open(tmp_path, 'w') as handle:
        handle.write(svg)
    os.replace(tmp_path, svg_path)

 40%|████      | 3671/9120 [2:39:58<4:00:26,  2.65s/it]

In [None]:
# Destination for the cleaned arrays. Defaults to the original location; set
# the CYP_DATA_DIR environment variable to save somewhere else.
CYP_DATA_DIR = os.environ.get('CYP_DATA_DIR', '/home/jose/cyp')
os.makedirs(CYP_DATA_DIR, exist_ok=True)

# Persist the InChI strings and their labels for later reuse.
np.save(os.path.join(CYP_DATA_DIR, 'inchis.npy'), arr=inchis)
np.save(os.path.join(CYP_DATA_DIR, 'values.npy'), arr=values)