In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem

from rad.construction import getGraphs
from rad.traversal import traverseHNSW

# Silence every RDKit 'rdApp.*' log channel (not only errors) so that messages from
# SMILES that fail to parse -- those molecules are skipped later -- don't flood cell output.
RDLogger.DisableLog('rdApp.*')

### Load the DUDE-Z DOCK scores dataset

In [None]:
!wget https://zenodo.org/records/10989077/files/dudez_dock_scores.pkl

In [None]:
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load a dudez_dock_scores.pkl that came from the trusted Zenodo record above.
with open('dudez_dock_scores.pkl', mode='rb') as pkl_file:
    dudez_data = pickle.load(pkl_file)

### Set parameters for fingerprints and generate them

In [None]:
# Morgan fingerprint settings, passed to GetMorganFingerprintAsBitVect for every molecule.
FP_LENGTH: int = 1024  # number of bits; later packed 8 bits per byte with np.packbits
FP_RADIUS: int = 2     # Morgan radius

In [None]:
# Build one packed Morgan fingerprint per parseable molecule. Node ids are assigned
# densely (0, 1, 2, ...) in the same order fingerprints are appended, so row i of
# dudez_fps and key i of dudez_scores describe the same molecule (score_fn relies on this).
dudez_fps = []
dudez_scores = {}
node_id = 0

for record in tqdm(dudez_data.values(), total=len(dudez_data), desc="Generating Fingerprints"):
    # presumably each record is (smiles, {receptor: dock_score, ...}); only [0] and [1] are read
    smi = record[0]
    scores = record[1]

    # Some SMILES fail molecule generation; skip them without consuming a node id.
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        continue

    # Convert the RDKit bit-vector fingerprint to a numpy array.
    # NOTE(review): AllChem.GetMorganFingerprintAsBitVect is deprecated in recent RDKit
    # releases in favour of rdFingerprintGenerator.GetMorganGenerator -- consider migrating.
    arr = np.zeros((1,), dtype=np.uint8)
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=FP_RADIUS, nBits=FP_LENGTH)
    DataStructs.ConvertToNumpyArray(fp, arr)

    # IMPORTANT: pack the bit fingerprint (8 bits per byte) - it vastly speeds up HNSW construction.
    dudez_fps.append(np.packbits(arr))
    dudez_scores[node_id] = scores
    node_id += 1

# Fail here with a clear message rather than deep inside HNSW construction:
# np.array([]) would silently become an empty float array.
assert dudez_fps, "no SMILES in dudez_data could be parsed into molecules"
assert len(dudez_fps) == len(dudez_scores)

dudez_fps = np.array(dudez_fps)

### Set parameters for HNSW and construct it

In [None]:
# HNSW build hyperparameters, passed to getGraphs as ef_construction= and M=.
# presumably the usual HNSW meanings (construction candidate-list size / neighbours
# per node) -- confirm against rad.construction.getGraphs.
EF_CONSTRUCTION: int = 400
M: int = 16

In [None]:
# Build the layered HNSW graph over the packed fingerprints
# (presumably one node per row of dudez_fps -- score_fn indexes scores by that row id).
hnsw_layer_graphs = getGraphs(
    dudez_fps,
    M=M,
    ef_construction=EF_CONSTRUCTION,
)

### Traverse the HNSW graph using docking scores for a single receptor

In [None]:
RECEPTOR: str = "AA2AR"            # receptor whose dock scores drive the traversal
NUM_TO_TRAVERSE: int = 100_000     # Maximum number of molecules to score

In [None]:
def score_fn(node_id):
    """Return the RECEPTOR dock score recorded for HNSW node `node_id`.

    Molecules with no score for RECEPTOR get np.inf. This is presumably the worst
    possible value for traverseHNSW, which is consistent with the ascending sort used
    below to pick virtual actives, where lower scores are treated as better.

    Raises KeyError if `node_id` is not a key of dudez_scores.
    """
    node_scores = dudez_scores[node_id]
    if RECEPTOR not in node_scores:
        return np.inf
    return node_scores[RECEPTOR]

In [None]:
# Walk the HNSW graph guided by score_fn, scoring at most NUM_TO_TRAVERSE molecules.
traversed_nodes = traverseHNSW(
    hnsw_layer_graphs,
    score_fn,
    NUM_TO_TRAVERSE,
)

### Enrichment plot: virtual actives recovered vs. molecules scored

In [None]:
# How many of the best-scoring molecules (lowest RECEPTOR score) count as "virtual actives".
VIRTUAL_ACTIVE_CUTOFF: int = 100

In [None]:
# Every molecule's RECEPTOR score; molecules without one get np.inf so they sort last.
receptor_scores = [
    (nid, scores[RECEPTOR] if RECEPTOR in scores else np.inf)
    for nid, scores in dudez_scores.items()
]
# Ascending: the lowest scores are treated as the best.
receptor_scores.sort(key=lambda pair: pair[1])

# The VIRTUAL_ACTIVE_CUTOFF best-scoring molecules are the "virtual actives".
# Unscored molecules (np.inf sentinel) are excluded: when fewer than VIRTUAL_ACTIVE_CUTOFF
# molecules have a RECEPTOR score they would otherwise be counted as actives and
# inflate the recovery curve.
virtual_actives = [
    nid for nid, score in receptor_scores[:VIRTUAL_ACTIVE_CUTOFF] if score != np.inf
]
assert virtual_actives, f"no molecules have a {RECEPTOR} score"

In [None]:
# Recovery curve: fraction of virtual actives found among the first n scored molecules.
assert len(virtual_actives) > 0, "no virtual actives to recover"

# 0-based position of every traversed node. Assumes traversed_nodes iterates in the order
# molecules were scored, exactly as the prefix slicing of its keys did before.
traversal_rank = {nid: rank for rank, nid in enumerate(traversed_nodes.keys())}
# Ranks of the virtual actives the traversal reached; never-reached actives never count.
active_ranks = np.array([traversal_rank[nid] for nid in virtual_actives if nid in traversal_rank])

# 100 evenly spaced checkpoints from 0 to NUM_TO_TRAVERSE molecules scored.
plot_mols_traversed = np.linspace(0, NUM_TO_TRAVERSE, 100, dtype=int)
# An active is among the first n scored molecules iff its rank < n; computing ranks once
# avoids rebuilding a list and a set of up to NUM_TO_TRAVERSE ids for every checkpoint.
plot_virtual_actives_recovered = [
    int(np.count_nonzero(active_ranks < n)) / len(virtual_actives)
    for n in plot_mols_traversed
]

In [None]:
fig, ax = plt.subplots()
# plot_virtual_actives_recovered holds fractions in [0, 1]; scale them so the
# values actually match the "Percent" axis label.
ax.plot(plot_mols_traversed, 100 * np.asarray(plot_virtual_actives_recovered))
ax.set_ylim(0, 100)
ax.set_xlabel("Number of Molecules Scored")
ax.set_ylabel("Percent of Virtual Actives Found")
ax.set_title(f"Enrichment Plot for RAD of {RECEPTOR}")
plt.show()