In [None]:
import random
import pickle
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import AllChem

from rad.traverser import RADTraverser
from usearch.index import Index

# Disable rdkit error logging to keep output clean
RDLogger.DisableLog('rdApp.*')  

In [None]:
# Number of worker processes used by the fingerprint-generation pool.
# NOTE(review): the traversal cell passes n_workers=1 explicitly, so this
# value does not currently affect the HNSW traversal.
N_WORKERS = 8

### Load the DUDEZ DOCK dataset

In [None]:
!curl -O https://zenodo.org/records/10989077/files/dudez_dock_scores.pkl

In [None]:
# Load the downloaded score file into memory.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# run this cell if the download source is trusted.
# Each entry is indexed below by Zinc ID, with the SMILES at position 0 and
# the score dictionary at position 1.
with open('dudez_dock_scores.pkl', 'rb') as f:
    dudez_data = pickle.load(f)

In [None]:
# Seed for the shuffle so the molecule order (and therefore the key assigned
# to each molecule) is the same on every Restart & Run All.
SHUFFLE_SEED = 42

# Flatten the loaded data into (Zinc ID, SMILES, score dictionary) tuples
data_list = [(zid, dudez_data[zid][0], dudez_data[zid][1]) for zid in dudez_data]

# Shuffle the data with a dedicated seeded generator, leaving the global
# `random` state untouched
random.Random(SHUFFLE_SEED).shuffle(data_list)

### Set parameters for fingerprints and generate them

In [None]:
# Number of bits per fingerprint: passed as nBits when generating the Morgan
# fingerprint and as ndim when constructing the HNSW index.
FP_LENGTH = 1024
# Radius passed to the Morgan fingerprint generator.
FP_RADIUS = 2

In [None]:
def generate_fingerprint(args):
    """Build a bit-packed Morgan fingerprint for one dataset record.

    Args:
        args: A ``(zid, smi, scores)`` tuple, bundled into a single argument
            so the function can be mapped over records by a worker pool.

    Returns:
        ``(packed_bits, scores)``, where ``packed_bits`` is the FP_LENGTH-bit
        fingerprint packed into a uint8 array with ``np.packbits`` and
        ``scores`` is passed through unchanged; ``None`` when RDKit cannot
        parse the SMILES.
    """
    _zid, smiles, record_scores = args
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is not None:
        bit_vect = AllChem.GetMorganFingerprintAsBitVect(molecule, radius=FP_RADIUS, nBits=FP_LENGTH)
        # One uint8 per bit first; RDKit fills this buffer in place
        unpacked = np.zeros((FP_LENGTH,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(bit_vect, unpacked)
        # Pack 8 bits per byte before returning
        return np.packbits(unpacked), record_scores
    return None

In [None]:
# Records sent to a worker per task. The default chunksize of 1 pays the
# inter-process overhead once per molecule; batching keeps workers busy while
# still leaving several chunks per worker on small datasets.
FP_CHUNKSIZE = max(1, min(1_000, len(data_list) // (N_WORKERS * 4)))

with multiprocessing.Pool(N_WORKERS) as p:
    # imap preserves input order, so results line up with data_list
    results = list(tqdm(
        p.imap(generate_fingerprint, data_list, chunksize=FP_CHUNKSIZE),
        total=len(data_list),
        desc="Generating Fingerprints",
    ))

# Filter molecules that failed fingerprint generation (returned as None)
results = [result for result in results if result is not None]

# Fail with a clear message instead of an opaque unpacking error below
if not results:
    raise ValueError("No fingerprints were generated; check the SMILES in the dataset")

# Format fingerprints and create keys for HNSW: one row per molecule, keyed
# by its position so `scores[key]` is that molecule's score dictionary
fps, scores = zip(*results)
fps = np.vstack(fps)
keys = np.arange(len(fps))

### Set parameters for HNSW and construct it

In [None]:
# Passed to the usearch Index as expansion_add.
EF_CONSTRUCTION = 400
# Passed to the usearch Index as connectivity.
M = 16

In [None]:
# Create the (empty) usearch index that the fingerprints are added to.
# ndim is given in bits (FP_LENGTH), while the fingerprint rows are packed
# with np.packbits.
# NOTE(review): presumably dtype='b1' means one bit per dimension, which is
# what makes the packed rows line up with ndim — confirm against usearch docs.
hnsw = Index(
    ndim = FP_LENGTH,
    dtype='b1',
    metric='tanimoto',
    connectivity = M,
    expansion_add = EF_CONSTRUCTION
)

In [None]:
# Insert every packed fingerprint under its integer key.
# NOTE(review): `log` is presumably the label for usearch's progress output —
# confirm against the usearch documentation.
hnsw.add(keys, fps, log="Building HNSW")

### Define Scoring Function and Set Up RADTraverser
##### Constructing the `RADTraverser` does two things:
1. Starts a process that handles HNSW neighbor queries
2. Starts a process that runs a redis-server for traversal logic (or connects to an existing redis server)

In [None]:
# Key looked up in each molecule's score dictionary; also shown in the
# enrichment plot title.
RECEPTOR = "LCK"

In [None]:
def score_fn(node_key):
    """Return the RECEPTOR score of the molecule stored under ``node_key``.

    Args:
        node_key: Index into ``scores`` identifying the molecule.

    Returns:
        The molecule's score for RECEPTOR, or the sentinel 9999 when its
        score dictionary has no entry for RECEPTOR.
    """
    node_scores = scores[node_key]
    return node_scores[RECEPTOR] if RECEPTOR in node_scores else 9999

In [None]:
# Build the traverser over the HNSW index, using score_fn to score nodes.
# Per the notebook's description, this starts a process for HNSW neighbor
# queries and a redis-server process (or connects to an existing one).
traverser = RADTraverser(hnsw=hnsw, scoring_fn=score_fn)

### "Prime" the traversal - initialize queue with top layer nodes

In [None]:
# Initialize the traversal queue with the top-layer nodes of the HNSW (per
# the notebook's description of this step).
traverser.prime()

### Perform the Traversal

In [None]:
# Passed to traverser.traverse as n_to_score.
NUM_TO_TRAVERSE = 100_000 # Maximum number of molecules to score

In [None]:
# More than one worker should normally be usable, but a bug currently
# prevents it, so n_workers is pinned to 1 here.
# This step should still only take a minute or two when traversing 100,000 molecules.
traverser.traverse(n_workers=1, n_to_score=NUM_TO_TRAVERSE)

### Graph Enrichment Plots

In [None]:
# Number of best-scoring (lowest-score) molecules treated as virtual actives;
# also the denominator used to normalize the enrichment curve.
VIRTUAL_ACTIVE_CUTOFF = 100

In [None]:
# Score every indexed molecule with the same function the traverser uses, so
# the missing-score sentinel is defined in exactly one place (score_fn).
receptor_scores = [(node_key, score_fn(node_key)) for node_key in range(len(scores))]

# Ascending sort: the lowest scores come first
receptor_scores.sort(key=lambda x: x[1])

# Keys of the VIRTUAL_ACTIVE_CUTOFF best-scoring molecules
virtual_actives = {node_key for node_key, _ in receptor_scores[:VIRTUAL_ACTIVE_CUTOFF]}

In [None]:
# Build the enrichment curve: after each scored molecule, record how many
# molecules have been scored (x) and how many of them are virtual actives (y).
# NOTE(review): assumes traverser.scored_set iterates in the order molecules
# were scored — confirm against RADTraverser.
x, y = [], []
va_found = 0
mols_traversed = 0
for mols_traversed, (key, score) in enumerate(traverser.scored_set, start=1):
    va_found += 1 if key in virtual_actives else 0
    x.append(mols_traversed)
    y.append(va_found)

# Normalize the running hit count by the number of virtual actives sought
y = np.array(y) / VIRTUAL_ACTIVE_CUTOFF

In [None]:
# Plot the enrichment curve on an explicit Axes. y is a fraction in [0, 1]
# (hit count divided by VIRTUAL_ACTIVE_CUTOFF), so the axis is labelled as a
# fraction rather than a percent.
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_ylim(0, 1)
ax.set_xlabel("Number of Molecules Scored")
ax.set_ylabel("Fraction of Virtual Actives Found")
ax.set_title(f"Enrichment Plot for RAD of {RECEPTOR}")
# Render the figure without echoing the repr of the last call
plt.show()

### Shut down the HNSW and redis server processes

In [None]:
# Stop the HNSW and redis server processes started for the traverser (per
# the notebook's description); run this even if you abandon the analysis.
traverser.shutdown()