In [None]:
!pip install torch-scatter==2.1.1 -f https://data.pyg.org/whl/torch-2.1.2+cpu.html
!pip install torch_geometric




In [None]:
import os, random, numpy as np, torch
import pandas as pd


from torch import nn
from torch.optim import AdamW
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, GATConv, global_mean_pool
from torch_scatter import scatter


from sklearn.metrics import roc_auc_score, average_precision_score, f1_score


!pip install rdkit
from rdkit import Chem
from rdkit.Chem import AllChem


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEEDS = [0, 1, 2, 3, 4]



In [None]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
import torch.serialization
from torch_geometric.data import Data


torch.serialization.add_safe_globals([Data])


graphs = torch.load("qs_graphs.pt", weights_only=False)
df = pd.read_csv("qs_inhibitors_cleaned.csv")


labels = torch.tensor(df["activity_label"].values, dtype=torch.float)
smiles = df["smiles_canonical"].values

In [None]:
def morgan_fp(smile, n_bits=2048):
    mol = Chem.MolFromSmiles(smile)
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)
    return torch.tensor(fp, dtype=torch.float)


fps = torch.stack([morgan_fp(s) for s in smiles])



In [None]:
scaffold_to_id = {s: i for i, s in enumerate(np.unique(smiles))}
scaffold_ids = torch.tensor([scaffold_to_id[s] for s in smiles])


unique_scaffolds = np.array(list(scaffold_to_id.keys()))
np.random.shuffle(unique_scaffolds)


n = len(unique_scaffolds)
train_s = unique_scaffolds[:int(0.7*n)]
val_s   = unique_scaffolds[int(0.7*n):int(0.85*n)]
test_s  = unique_scaffolds[int(0.85*n):]


def split_idx(scafs):
    return np.where(np.isin(smiles, scafs))[0]


SPLITS = {
    "train": split_idx(train_s),
    "val":   split_idx(val_s),
    "test":  split_idx(test_s)
}

In [None]:
class QSDataset(torch.utils.data.Dataset):
    def __init__(self, graphs, labels, fps, scaffold_ids, indices):
        self.graphs = graphs
        self.labels = labels
        self.fps = fps
        self.scaffold_ids = scaffold_ids
        self.indices = indices


    def __len__(self):
        return len(self.indices)


    def __getitem__(self, i):
        idx = self.indices[i]
        g = self.graphs[idx]
        g.y = self.labels[idx]
        g.fp = self.fps[idx]
        g.scaffold_id = self.scaffold_ids[idx]
        return g

In [None]:
def augment_graph(data, node_mask_p=0.15, edge_drop_p=0.1):
    data = data.clone()


    mask = torch.rand(data.x.size(0), device=data.x.device) < node_mask_p
    data.x[mask] = 0


    keep = torch.rand(data.edge_index.size(1), device=data.edge_index.device) > edge_drop_p
    data.edge_index = data.edge_index[:, keep]
    data.edge_attr = data.edge_attr[keep]


    return data

In [None]:
class HybridEncoder(nn.Module):
    def __init__(self, in_dim, edge_dim, hidden=128):
        super().__init__()


        self.gin = GINConv(
            nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden)
            )
        )


        self.gat = GATConv(
            in_dim, hidden // 4,
            heads=4, edge_dim=edge_dim
        )


        self.norm = nn.LayerNorm(hidden)


    def forward(self, data):
        h1 = self.gin(data.x, data.edge_index)
        h2 = self.gat(data.x, data.edge_index, data.edge_attr)


        h = self.norm(h1 + h2)
        g = global_mean_pool(h, data.batch)
        return g, h

In [None]:
class QSGNN(nn.Module):
    def __init__(self, in_dim, edge_dim, fp_dim=2048):
        super().__init__()
        self.encoder = HybridEncoder(in_dim, edge_dim)


        self.fp_proj = nn.Linear(fp_dim, 128)


        self.projector = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )


        self.classifier = nn.Linear(256, 1)


    def forward(self, data):
        g_emb, node_emb = self.encoder(data)
        fp_emb = self.fp_proj(data.fp.view(data.num_graphs, -1))


        fused = torch.cat([g_emb, fp_emb], dim=1)
        z = self.projector(fused)
        logits = self.classifier(fused).squeeze()
        return z, node_emb, logits

In [None]:
def nt_xent(z1, z2, temp=0.5):
    z1 = nn.functional.normalize(z1)
    z2 = nn.functional.normalize(z2)
    logits = z1 @ z2.T / temp
    labels = torch.arange(z1.size(0), device=z1.device)
    return nn.CrossEntropyLoss()(logits, labels)


def scaffold_contrastive(z, scaffold_ids):
    z = nn.functional.normalize(z)
    sim = z @ z.T
    same = scaffold_ids.unsqueeze(0) == scaffold_ids.unsqueeze(1)
    mask = ~torch.eye(len(z), device=z.device).bool()
    return -(sim[same & mask].mean() - sim[~same & mask].mean())

In [None]:
def pretrain(model, loader, epochs=30):
    opt = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    model.train()


    for _ in range(epochs):
        for g in loader:
            g = g.to(DEVICE)
            g1, g2 = augment_graph(g), augment_graph(g)


            z1, _, _ = model(g1)
            z2, _, _ = model(g2)



            scaf = g.scaffold_id
            loss = nt_xent(z1, z2) + scaffold_contrastive(z1, scaf)


            opt.zero_grad()
            loss.backward()
            opt.step()


In [None]:
def finetune(model, loader, train_labels, epochs=40):
    pos_weight = torch.tensor(
        [(train_labels == 0).sum() / (train_labels == 1).sum()],
        device=DEVICE
    )
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    opt = AdamW(model.parameters(), lr=1e-3)
    model.eval()

    model.train()
    for _ in range(epochs):
        for g in loader:
            g = g.to(DEVICE)
            _, _, logits = model(g)
            loss = criterion(logits, g.y)
            opt.zero_grad()
            loss.backward()
            opt.step()


In [None]:
def evaluate(model, loader):
    model.eval()
    y, p = [], []


    with torch.no_grad():
        for g in loader:
            g = g.to(DEVICE)
            _, _, logits = model(g)
            y.append(g.y.cpu())
            p.append(torch.sigmoid(logits).cpu())


    y = torch.cat(y)
    p = torch.cat(p)


    return {
        "ROC_AUC": roc_auc_score(y, p),
        "PR_AUC": average_precision_score(y, p),
        "F1": f1_score(y, p > 0.5)
    }

In [None]:
results = []


in_dim = graphs[0].x.shape[1]
edge_dim = graphs[0].edge_attr.shape[1]


for seed in SEEDS:
    set_seed(seed)


    model = QSGNN(in_dim, edge_dim).to(DEVICE)


    train_ds = QSDataset(graphs, labels, fps, scaffold_ids, SPLITS["train"])
    val_ds   = QSDataset(graphs, labels, fps, scaffold_ids, SPLITS["val"])
    test_ds  = QSDataset(graphs, labels, fps, scaffold_ids, SPLITS["test"])


    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    test_loader  = DataLoader(test_ds, batch_size=32)


    pretrain(model, train_loader)
    finetune(model, train_loader, labels[SPLITS["train"]])


    results.append(evaluate(model, test_loader))


df = pd.DataFrame(results)
print(df)
print("\nMean:\n", df.mean())
print("\nStd:\n", df.std())

    ROC_AUC    PR_AUC        F1
0  0.869048  0.887992  0.827586
1  0.875000  0.904960  0.827586
2  0.880952  0.909217  0.827586
3  0.880952  0.910372  0.827586
4  0.869048  0.896032  0.827586

Mean:
 ROC_AUC    0.875000
PR_AUC     0.901715
F1         0.827586
dtype: float64

Std:
 ROC_AUC    5.952381e-03
PR_AUC     9.517963e-03
F1         1.241267e-16
dtype: float64


In [None]:
import pandas as pd


dde = pd.read_csv("DDEB.smi", sep=r"\s+", engine="python")
cae = pd.read_csv("CAED.smi", sep=r"\s+", engine="python")

print("DDEB size:", len(dde))
print("CAED size:", len(cae))


zinc_all = pd.concat([dde, cae], ignore_index=True)


zinc_all = zinc_all.drop_duplicates(subset="smiles").reset_index(drop=True)

print("Combined unique molecules:", len(zinc_all))


zinc_all.to_csv("ZINC_combined_66k.smi", sep=" ", index=False)


DDEB size: 19585
CAED size: 79643
Combined unique molecules: 69295


In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from torch_geometric.data import Data

def mol_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None


    atom_feats = []
    for atom in mol.GetAtoms():
        atom_feats.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization().real,
            atom.GetIsAromatic(),
            atom.GetTotalNumHs()
        ])
    x = torch.tensor(atom_feats, dtype=torch.float)


    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index += [[i, j], [j, i]]


        bond_feats = [
            float(bond.GetBondType() == Chem.BondType.SINGLE),
            float(bond.GetBondType() == Chem.BondType.DOUBLE),
            float(bond.GetBondType() == Chem.BondType.TRIPLE),
            float(bond.GetBondType() == Chem.BondType.AROMATIC),
            float(bond.GetIsConjugated()),
            float(bond.IsInRing())
        ]
        edge_attr += [bond_feats] * 2

    edge_index = torch.tensor(edge_index).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)


    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    fp = torch.tensor(fp, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, fp=fp)


In [None]:
graphs = []
valid_smiles = []

for smi in zinc_all["smiles"]:
    g = mol_to_graph(smi)
    if g is not None:
        graphs.append(g)
        valid_smiles.append(smi)

print("Graphs created:", len(graphs))



Graphs created: 69292




In [None]:
from torch_geometric.loader import DataLoader
import torch.nn.functional as F

model.eval()

loader = DataLoader(graphs, batch_size=64, shuffle=False)

scores = []

with torch.no_grad():
    for batch in loader:
        batch = batch.to(DEVICE)
        _, _, logits = model(batch)
        probs = torch.sigmoid(logits)
        scores.extend(probs.cpu().numpy())


In [None]:
zinc_all = zinc_all.iloc[:len(scores)].copy()
zinc_all["predicted_activity"] = scores


top_hits = zinc_all.sort_values(
    by="predicted_activity",
    ascending=False
)

top_hits.head(20)


Unnamed: 0,smiles,zinc_id,predicted_activity
8830,N#Cc1ccc(N2CCC(Nc3ccncc3C(N)=O)CC2)nc1,97140690.0,0.999861
5266,CCCCn1c(=O)nc(O)c2c(C(C)C)c(C#N)c(N)nc21,24632330.0,0.999833
18443,COc1cc(/C=N/Nc2ncnc(N)c2[N+](=O)[O-])ccc1O,5446906.0,0.999816
9622,CN(CCCc1n[nH]c(N)c1C#N)c1ccc(F)cc1[N+](=O)[O-],40468550.0,0.999745
361,Cc1ccnc(N2CC[C@H](N(C)Cc3nccc(N)n3)C2)c1C#N,372519200.0,0.999719
3154,Cc1ccnc(N2CC[C@@H](N(C)Cc3nccc(N)n3)C2)c1C#N,372519200.0,0.999719
7475,Cc1nc([C@@H]2CCO[C@@H]2CN(C)c2ncc(C#N)cc2F)n[nH]1,1775960000.0,0.999646
4539,CN(CC(=O)Nc1c(O)cccc1F)c1ccc(C#N)cn1,83857610.0,0.999609
13158,N#Cc1ccccc1OCC(=O)N[C@@H]1CCN(c2cccnc2)C1,370807300.0,0.99958
6933,N#Cc1ccccc1OCC(=O)N[C@H]1CCN(c2cccnc2)C1,370807300.0,0.99958


In [None]:
import pandas as pd
import torch

model.eval()

results = []


for idx, current_smiles in enumerate(valid_smiles):

    current_zinc_id = zinc_all['zinc_id'].iloc[idx]
    prob = scores[idx]

    results.append({
        "index": idx,
        "smiles": current_smiles,
        "zinc_id": current_zinc_id,
        "predicted_activity": prob
    })

df = pd.DataFrame(results)
df = df.sort_values("predicted_activity", ascending=False).reset_index(drop=True)

In [None]:
df["rank"] = df.index + 1


df = df[df["predicted_activity"] > 0.90]


In [None]:
top_100 = df.head(100)
top_20 = df.head(20)
bottom_50 = df.tail(50)

top_100.to_csv("top_100_for_docking.csv", index=False)
top_20.to_csv("top_20_visualization.csv", index=False)
bottom_50.to_csv("bottom_50_negative_controls.csv", index=False)


In [None]:
def export_smiles(df, filename):
    with open(filename, "w") as f:
        for s in df["smiles"]:
            f.write(s + "\n")

export_smiles(top_100, "top_100_smiles.smi")
export_smiles(bottom_50, "bottom_50_smiles.smi")


In [None]:
print("Top hit probability:", top_100.iloc[0]["predicted_activity"])
print("Total docking candidates:", len(top_100))


Top hit probability: 0.999861
Total docking candidates: 100
