In [1]:
#if not already installed
!python3 -m pip install h5py sklearn



In [1]:
#imports
import os
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import functools
import matplotlib.pyplot as plt
import math
import scipy

from sklearn import svm, datasets
from sklearn.metrics import auc
from sklearn.metrics import plot_roc_curve, roc_curve, roc_auc_score
from sklearn.model_selection import StratifiedKFold

import rdkit
from rdkit.Chem import Descriptors
from rdkit import Chem
from rdkit.Chem import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import RDConfig
from rdkit import rdBase
from rdkit.Chem.Draw import IPythonConsole

In [46]:
# Directory containing the per-assay HDF5 embedding files (listed in the next cell).
embeddings_path = "embeddings/"

# Select a Dataset
* Each dataset is an HDF5 binary file (read with `h5py`)
* Each file contains an array with the Transformer embedding of every molecule
* plus additional arrays with the SMILES strings, binding labels, etc.

In [47]:
#how many epochs was the transformer trained before generating embeddings?
# Set to None to list every file regardless of epoch.
epoch_id = "2"

# os.listdir returns entries in arbitrary order; sort so that `assay_idx`
# (next cell) selects the same file on every machine and every run.
assays = sorted(os.listdir(embeddings_path))
if epoch_id is not None:
    # keep files whose stem ends in "_<epoch_id>", e.g. "hiv1_protease_2.hdf5"
    assays = [
        assay for assay in assays
        if os.path.splitext(assay)[0].split("_")[-1] == str(epoch_id)
    ]
print(assays)

['hiv1_protease_2.hdf5']


In [11]:
# Choose one of the files listed above by its position in `assays` (0 = first entry).
assay_idx = 0
assay_file = assays[assay_idx]
assay_path = os.path.join(embeddings_path, assay_file)
print(assay_path)

embeddings/hiv1_protease_2.hdf5


# Load and Prepare Data

In [12]:
# Open the selected HDF5 file read-only. It is deliberately left open:
# later cells index into its datasets.
assay = h5py.File(assay_path, mode='r')

In [13]:
# Named handles to the datasets stored in the opened file.
smiles_enc = assay['smiles']       # encoded SMILES strings
labels_binding = assay['binding']  # 0 or 1 ("not binding" / "binding")
labels_result = assay['result']    # numeric assay result

In [15]:
# Read the whole binding-label dataset into memory (an in-memory array, used with
# fancy indexing by `undersample` below).
binding = labels_binding[:]

In [16]:
# Labels are 0/1 (see the 'binding' dataset comment above), so the sum is the positive count.
binding.sum() #how many "binding" molecules

2159

In [17]:
def undersample(idxs, labels, ratio=1, rng=None):
    """Class-balance a set of sample indices by random undersampling.

    Both classes are truncated to ``int(ratio * minority_count)`` samples, where
    ``minority_count`` is the size of the smaller class. With ``ratio=1`` the
    result is balanced 50/50. With ``ratio > 1`` the smaller class is kept whole
    and the larger class contributes up to ``ratio`` times as many samples.

    Args:
        idxs: 1-D integer numpy array of candidate indices into ``labels``.
        labels: array of 0/1 labels that can be fancy-indexed by ``idxs``.
        ratio: multiplier on the minority-class count; may be fractional.
        rng: optional ``numpy.random.Generator``. When None (default) the
            global ``np.random`` state is used, as in the original code.

    Returns:
        Shuffled 1-D numpy array of the selected indices. ``idxs`` is not modified.
    """
    shuffle = np.random.shuffle if rng is None else rng.shuffle

    # Boolean-mask indexing returns copies, so the in-place shuffles never touch `idxs`.
    no_bind_idxs = idxs[labels[idxs] == 0]
    bind_idxs = idxs[labels[idxs] == 1]
    # int() so that fractional ratios are valid slice bounds.
    keep = int(min(len(bind_idxs), len(no_bind_idxs)) * ratio)

    shuffle(no_bind_idxs)
    shuffle(bind_idxs)

    selected = np.concatenate((no_bind_idxs[:keep], bind_idxs[:keep]))
    # Mix the two classes so they are not grouped by label.
    shuffle(selected)
    return selected

In [18]:
# Fix the global NumPy seed so the undersampled subset (and everything derived
# from it) is reproducible across Restart & Run All.
SEED = 0
np.random.seed(SEED)

#increase ratio to include more "non-binding" samples
# (assumes "non-binding" is the larger class, as the counts above suggest)
idxs = undersample(np.arange(len(binding)), binding, ratio=1)
idxs.shape

(4318,)

In [19]:
def reduce_data(data, idxs):
    """Gather the rows at ``idxs`` from each array-like in ``data``.

    Each source is indexed one element at a time, so h5py datasets and numpy
    arrays are both accepted.

    Returns:
        A list with one torch tensor per entry of ``data``, in the same order,
        whose first dimension follows the order of ``idxs``.
    """
    tensors = []
    for source in data:
        rows = [source[i] for i in idxs]
        tensors.append(torch.tensor(np.stack(rows)))
    return tensors

In [20]:
# Materialise the undersampled rows as torch tensors:
# `sm` = encoded SMILES, `y` = binding labels.
sm, y = reduce_data(data=[smiles_enc, binding], idxs=idxs)
print(*(t.shape for t in (sm, y)))

torch.Size([4318, 256, 512]) torch.Size([4318, 256]) torch.Size([4318])


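# Calculate Tanimoto Similarity

This section checks how much the cross-validation test folds overlap the training data. For each test molecule, it finds the most similar training molecule, using the Tanimoto similarity of RDKit fingerprints, and averages that maximum over the fold.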

In [None]:
def decode_smiles(row):
    """Decode one encoded SMILES row (1-D tensor) back into a string.

    Each stored value v maps to ``chr(round(v * 98) + 32)``. Position 0 is skipped
    (presumably a start token — TODO confirm). Surrounding ``chr(129)`` characters
    (presumably padding — TODO confirm) are stripped.
    """
    # tolist() yields the same Python floats as per-element .item(), in one call.
    return ''.join(chr(round(v * 98) + 32) for v in row[1:].tolist()).strip(chr(129))


# Parse every decoded SMILES with RDKit and compute its fingerprint.
# Rows that cannot be parsed are dropped from `sm` and `y` so that `fps[k]`
# stays aligned with row k of the tensors.
fps = []
found_idxs = []
failed_idxs = []
for i in range(sm.shape[0]):
    if i % 1000 == 0:
        print(i)
    smiles = decode_smiles(sm[i])
    try:
        mol = Chem.MolFromSmiles(smiles)
        fp = Chem.RDKFingerprint(mol) if mol is not None else None
    except Exception:
        # Best-effort: skip bad rows, but unlike a bare `except:` this does not
        # swallow KeyboardInterrupt/SystemExit.
        fp = None
    if fp is None:
        failed_idxs.append(i)
        continue
    fps.append(fp)
    found_idxs.append(i)

print(f"could not load {len(failed_idxs)} of {sm.shape[0]} molecules")
# dtype=torch.long keeps this a valid index tensor even if the list is empty.
found_idxs = torch.tensor(found_idxs, dtype=torch.long)
sm, y = sm[found_idxs], y[found_idxs]

In [None]:
# Symmetric matrix of pairwise Tanimoto similarities between all fingerprints.
# DataStructs.FingerprintSimilarity defaults to the Tanimoto metric, so
# BulkTanimotoSimilarity gives the same values, computing a whole row per call.
assert len(fps) == sm.shape[0], "fingerprints are out of sync with `sm`; re-run the cell above"
n_mols = len(fps)
similarities = torch.zeros((n_mols, n_mols), dtype=torch.float)
for i in range(n_mols):
    # similarities of molecule i to molecules i..n-1 (upper triangle incl. diagonal)
    row = torch.tensor(DataStructs.BulkTanimotoSimilarity(fps[i], fps[i:]), dtype=torch.float)
    similarities[i, i:] = row
    similarities[i:, i] = row
    if i % 1000 == 0:
        print(i)

In [None]:
# For each stratified CV fold, take every test molecule's maximum similarity
# to any training molecule, then average over the fold's test molecules.
cv = StratifiedKFold(n_splits=10)
mean_similarities = []
for train_idxs, test_idxs in cv.split(sm, y):
    train_t = torch.as_tensor(train_idxs, dtype=torch.long)
    test_t = torch.as_tensor(test_idxs, dtype=torch.long)
    # rows: test molecules, columns: training molecules
    fold_sims = similarities[test_t][:, train_t]
    # similarities are >= 0, so this matches the original loop that started from 0
    nearest = fold_sims.max(dim=1).values
    mean_similarities.append(nearest.mean().item())
print(mean_similarities)
print(np.array(mean_similarities).mean())