# *De Novo* Molecule Reasoning with PROTON

This notebook enables reasoning over de novo molecules without retraining the PROTON model by:

1. Using **Uni-Mol2**, a pre-trained structure-aware molecular encoder, to generate embeddings from SMILES strings.
2. Training a small **adapter network** to map Uni-Mol2 representations into PROTON's 512-dimensional embedding space.
3. Using the projected embeddings with PROTON's decoder for **drug-disease predictions**.

The adapter is trained on the existing NeuroKG drugs for which a SMILES string can be retrieved (a subset of the 8,160 drug nodes), teaching it to reconstruct PROTON's learned drug embeddings from their chemical structures.

## Setup and Dependencies

Install molecular encoding dependencies (Uni-Mol2 and RDKit).

In [None]:
# Install dependencies for molecular encoding
# Uncomment and run if not already installed
# !uv add "unimol_tools<0.1.5" "lifelines<0.29.0" rdkit-pypi

In [1]:
import logging
import os
import time
from pathlib import Path

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from src.config import conf
from src.constants import TORCH_DEVICE
from src.dataloaders import load_graph
from src.models import HGT

if any("Arial" in f.name for f in fm.fontManager.ttflist):
    plt.rcParams["font.family"] = "Arial"

_logger = logging.getLogger(__name__)

Define output directories.

In [2]:
# Output directories
DRUGS_DIR = Path("data/neurokg/drugs")
DRUGS_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_DIR = Path("data/notebooks/molecular_analysis")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Cache file paths
SMILES_CACHE_PATH = DRUGS_DIR / "drug_smiles.csv"
UNIMOL_CACHE_PATH = DRUGS_DIR / "unimol_embeddings.pt"
ADAPTER_PATH = conf.paths.checkpoint.base_dir / "molecular_adapter.pt"

## Load PROTON Model and Drug Data

Read CSVs of nodes and edges.

In [3]:
nodes = pd.read_csv(conf.paths.kg.nodes_path, dtype={"node_index": int}, low_memory=False)
edges = pd.read_csv(
    conf.paths.kg.edges_path, dtype={"edge_index": int, "x_index": int, "y_index": int}, low_memory=False
)

_logger.info(f"Number of nodes: {len(nodes)}")
_logger.info(f"Number of edges: {len(edges) / 2:.0f}")

Load knowledge graph and PROTON model.

In [4]:
pl.seed_everything(conf.seed, workers=True)
kg = load_graph(nodes, edges)
pretrain_model = HGT.load_from_checkpoint(
    checkpoint_path=str(conf.paths.checkpoint.checkpoint_path),
    kg=kg,
    strict=False,
)
pretrain_model.eval()
pretrain_model = pretrain_model.to(TORCH_DEVICE)

Seed set to 42


Load pre-computed PROTON embeddings.

In [5]:
embeddings = torch.load(conf.paths.checkpoint.embeddings_path)
_logger.info(f"Embeddings shape: {embeddings.shape}")

EMBEDDING_DIM = embeddings.shape[1]
_logger.info(f"PROTON embedding dimension: {EMBEDDING_DIM}")

Extract drug nodes and their embeddings.

In [6]:
# Filter to drug nodes
drug_nodes = nodes[nodes["node_type"] == "drug"].copy().reset_index(drop=True)
_logger.info(f"Number of drug nodes: {len(drug_nodes)}")

# Extract drug embeddings from PROTON
drug_indices = drug_nodes["node_index"].values
drug_embeddings = embeddings[drug_indices]
_logger.info(f"Drug embeddings shape: {drug_embeddings.shape}")

## SMILES Data Collection

Fetch SMILES structures for drugs in NeuroKG from PubChem. Results are cached to avoid repeated API calls.

In [21]:
def get_smiles_from_pubchem(drugbank_id: str, timeout: int = 10) -> str | None:
    """Fetch SMILES from PubChem given a DrugBank ID.

    Args:
        drugbank_id: DrugBank identifier (e.g., 'DB00001')
        timeout: Request timeout in seconds

    Returns:
        Isomeric SMILES string if found, None otherwise
    """
    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{drugbank_id}/property/IsomericSMILES/JSON"
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
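        props = response.json()["PropertyTable"]["Properties"][0]
        # Prefer the "SMILES" key; fall back to the requested property name
        # ("IsomericSMILES") in case the response is keyed by it instead.
        return props.get("SMILES") or props.get("IsomericSMILES")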
    except Exception:
        return None


def fetch_drug_smiles(
    drug_nodes: pd.DataFrame,
    cache_path: Path,
    rate_limit: float = 0.1,
) -> pd.DataFrame:
    """Fetch SMILES for all drugs, using cache if available.

    Args:
        drug_nodes: DataFrame with drug node information
        cache_path: Path to cache CSV file
        rate_limit: Seconds to wait between API calls

    Returns:
        DataFrame with SMILES column added
    """
    # Check for cached data
    if cache_path.exists():
        _logger.info(f"Loading cached SMILES from {cache_path}")
        cached = pd.read_csv(cache_path, dtype={"node_index": int})
        return cached

    # Fetch SMILES from PubChem
    _logger.info("Fetching SMILES from PubChem API...")
    smiles_list = []
    for drug_id in tqdm(drug_nodes["node_id"], desc="Fetching SMILES"):
        smiles = get_smiles_from_pubchem(drug_id)
        smiles_list.append(smiles)
        time.sleep(rate_limit)  # Rate limiting

    # Add SMILES to dataframe
    result = drug_nodes[["node_index", "node_id", "node_name", "node_type"]].copy()
    result["SMILES"] = smiles_list

    # Save cache
    result.to_csv(cache_path, index=False)
    _logger.info(f"Saved SMILES cache to {cache_path}")

    return result
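
The PubChem lookup above can be exercised offline. The following is a minimal sketch, not part of the notebook: the helper names `build_pubchem_url` and `parse_smiles` and the sample payload are illustrative assumptions, and the parser accepts either `SMILES` or `IsomericSMILES` as the property key because the exact key in the response is not confirmed here.

```python
import json
from typing import Optional
from urllib.parse import quote

PUG_BASE = "https://pubchem.ncbi.nlm.nih.gov/rest/pug"


def build_pubchem_url(drugbank_id: str) -> str:
    """Build the name-lookup URL used above for a DrugBank ID."""
    return f"{PUG_BASE}/compound/name/{quote(drugbank_id)}/property/IsomericSMILES/JSON"


def parse_smiles(payload: str) -> Optional[str]:
    """Extract a SMILES string from a property-table response, or return None."""
    try:
        props = json.loads(payload)["PropertyTable"]["Properties"][0]
    except (ValueError, KeyError, IndexError, TypeError):
        # Malformed JSON or an unexpected structure counts as "not found".
        return None
    return props.get("SMILES") or props.get("IsomericSMILES")


# Hypothetical payload shaped like a property table (the CID is made up).
sample = json.dumps(
    {"PropertyTable": {"Properties": [{"CID": 1, "SMILES": "NC(Cc1ccc(O)c(O)c1)C(=O)O"}]}}
)
url = build_pubchem_url("DB00001")
smiles = parse_smiles(sample)
missing = parse_smiles('{"Fault": {"Code": "PUGREST.NotFound"}}')
```

Separating URL construction from parsing keeps the network call out of the unit under test, so malformed or empty responses can be checked without hitting the API.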

Fetch or load SMILES data.

In [26]:
drug_smiles_df = fetch_drug_smiles(drug_nodes, SMILES_CACHE_PATH)

# Report coverage
total_drugs = len(drug_smiles_df)
valid_smiles = drug_smiles_df["SMILES"].notna().sum()
coverage = 100 * valid_smiles / total_drugs

_logger.info(f"Retrieved SMILES for {valid_smiles}/{total_drugs} drugs ({coverage:.1f}%)")

Fetching SMILES: 100%|██████████████████████████████████████████████| 8160/8160 [33:27<00:00,  4.07it/s]


Filter to drugs with valid SMILES.

In [27]:
# Filter to valid SMILES
valid_drugs_df = drug_smiles_df[drug_smiles_df["SMILES"].notna()].copy().reset_index(drop=True)
_logger.info(f"Drugs with valid SMILES: {len(valid_drugs_df)}")

# Get corresponding PROTON embeddings for valid drugs
valid_drug_indices = valid_drugs_df["node_index"].values
valid_drug_embeddings = embeddings[valid_drug_indices]
_logger.info(f"Valid drug embeddings shape: {valid_drug_embeddings.shape}")

## Uni-Mol2 Molecular Encoding

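Use the pre-trained Uni-Mol2 model to encode each SMILES string into a 512-dimensional, molecule-level (CLS) representation. `UniMolRepr` is called with its default model settings; in the run logged below, these load the `mol_pre_all_h_220816.pt` weights.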

In [28]:
def encode_molecules_unimol(
    smiles_list: list[str],
    cache_path: Path | None = None,
    batch_size: int = 32,
) -> torch.Tensor:
    """Encode SMILES strings using Uni-Mol2.

    Args:
        smiles_list: List of SMILES strings
        cache_path: Optional path to cache embeddings
        batch_size: Batch size for encoding

    Returns:
        Tensor of molecular embeddings [N, 512]
    """
    # Check cache
    if cache_path is not None and cache_path.exists():
        _logger.info(f"Loading cached Uni-Mol2 embeddings from {cache_path}")
        return torch.load(cache_path)

    # Import Uni-Mol2
    try:
        from unimol_tools import UniMolRepr
    except ImportError as e:
        raise ImportError(
            "Uni-Mol2 not installed. Run: pip install unimol_tools"
        ) from e

    # Initialize Uni-Mol2 model
    _logger.info("Loading Uni-Mol2 model...")
    clf = UniMolRepr(data_type="molecule", remove_hs=False)

    # Encode in batches
    _logger.info(f"Encoding {len(smiles_list)} molecules...")
    all_embeddings = []

    for i in tqdm(range(0, len(smiles_list), batch_size), desc="Encoding molecules"):
        batch = smiles_list[i : i + batch_size]
        # UniMolRepr returns dict with 'cls_repr' key containing [batch, 512] embeddings
        reprs = clf.get_repr(batch, return_atomic_reprs=False)
        # Extract CLS token representation (molecule-level)
        batch_emb = torch.tensor(reprs["cls_repr"], dtype=torch.float32)
        all_embeddings.append(batch_emb)

    embeddings = torch.cat(all_embeddings, dim=0)
    _logger.info(f"Uni-Mol2 embeddings shape: {embeddings.shape}")

    # Save cache
    if cache_path is not None:
        torch.save(embeddings, cache_path)
        _logger.info(f"Saved Uni-Mol2 embeddings to {cache_path}")

    return embeddings
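
The batching loop above slices the SMILES list into chunks of `batch_size`, with a shorter final chunk when the list length is not a multiple of the batch size. As a small illustration (the helper names `iter_batches` and `num_batches` are hypothetical), the sketch below reproduces the slicing and the batch-count arithmetic. By the same arithmetic, the 217 batches reported by the progress bar at a batch size of 32 correspond to between 6,913 and 6,944 molecules.

```python
from math import ceil
from typing import Iterator, Sequence


def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def num_batches(n_items: int, batch_size: int) -> int:
    """Number of batches the loop above produces for `n_items` inputs."""
    return ceil(n_items / batch_size)


# Toy inputs: 70 placeholder strings, not real drug structures.
toy_smiles = ["C" * (i + 1) for i in range(70)]
sizes = [len(batch) for batch in iter_batches(toy_smiles, 32)]  # two full batches, one partial
```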

Encode drugs with Uni-Mol2.

In [29]:
# Get SMILES list for valid drugs
smiles_list = valid_drugs_df["SMILES"].tolist()

# Encode with Uni-Mol2
unimol_embeddings = encode_molecules_unimol(smiles_list, cache_path=UNIMOL_CACHE_PATH)
_logger.info(f"Uni-Mol2 embeddings shape: {unimol_embeddings.shape}")

2026-02-01 16:17:28 | unimol_tools/weights/weighthub.py | 32 | INFO | Uni-Mol Tools | Weights will be downloaded to default directory: /Users/an583/Documents/Zitnik_Lab/PROTON-GEM/.venv/lib/python3.11/site-packages/unimol_tools/weights


2026-02-01 16:17:28 | unimol_tools/weights/weighthub.py | 49 | INFO | Uni-Mol Tools | Downloading mol_pre_all_h_220816.pt


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

mol_pre_all_h_220816.pt:   0%|          | 0.00/191M [00:00<?, ?B/s]

2026-02-01 16:17:37 | unimol_tools/weights/weighthub.py | 32 | INFO | Uni-Mol Tools | Weights will be downloaded to default directory: /Users/an583/Documents/Zitnik_Lab/PROTON-GEM/.venv/lib/python3.11/site-packages/unimol_tools/weights


2026-02-01 16:17:37 | unimol_tools/weights/weighthub.py | 49 | INFO | Uni-Mol Tools | Downloading mol.dict.txt


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

mol.dict.txt:   0%|          | 0.00/91.0 [00:00<?, ?B/s]

2026-02-01 16:17:38 | unimol_tools/models/unimol.py | 136 | INFO | Uni-Mol Tools | Loading pretrained weights from /Users/an583/Documents/Zitnik_Lab/PROTON-GEM/.venv/lib/python3.11/site-packages/unimol_tools/weights/mol_pre_all_h_220816.pt


Encoding molecules:   0%|                                                       | 0/217 [00:00<?, ?it/s]2026-02-01 16:17:38 | unimol_tools/data/conformer.py | 167 | INFO | Uni-Mol Tools | Start generating conformers...


32it [00:06,  4.99it/s]
2026-02-01 16:17:45 | unimol_tools/data/conformer.py | 182 | INFO | Uni-Mol Tools | Succeeded in generating conformers for 96.88% of molecules.


2026-02-01 16:17:45 | unimol_tools/data/conformer.py | 191 | INFO | Uni-Mol Tools | Failed conformers indices: [0]


2026-02-01 16:17:45 | unimol_tools/data/conformer.py | 199 | INFO | Uni-Mol Tools | Succeeded in generating 3d conformers for 93.75% of molecules.


2026-02-01 16:17:45 | unimol_tools/data/conformer.py | 208 | INFO | Uni-Mol Tools | Failed 3d conformers indices: [0, 2]


2026-02-01 16:17:45 | unimol_tools/tasks/trainer.py | 103 | INFO | Uni-Mol Tools | Using CPU.


100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.10s/it]
  0%|                                                                             | 0/1 [00:01<?, ?it/s]
Encoding molecules:   5%|██▎                                           | 11/217 [01:36<30:15,  8.81s/it]


## Adapter Architecture and Training

Train a small MLP adapter to map Uni-Mol2 representations into PROTON's embedding space.

In [None]:
class MolecularAdapter(nn.Module):
    """MLP adapter to map Uni-Mol2 molecular representations to PROTON embedding space.

    Architecture: input -> Linear -> LayerNorm -> GELU -> Dropout -> ... -> output
    """

    def __init__(
        self,
        input_dim: int = 512,
        hidden_dim: int = 512,
        output_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        """Initialize the adapter.

        Args:
            input_dim: Uni-Mol2 embedding dimension
            hidden_dim: Hidden layer dimension
            output_dim: PROTON embedding dimension
            num_layers: Number of hidden layers
            dropout: Dropout probability
        """
        super().__init__()

        layers = []

        # Input layer
        layers.extend([
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        ])

        # Hidden layers
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])

        # Output layer
        layers.append(nn.Linear(hidden_dim, output_dim))

        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Uni-Mol2 embeddings [batch, input_dim]

        Returns:
            Projected embeddings [batch, output_dim]
        """
        return self.layers(x)
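
To make "small adapter" concrete, the parameter count of this architecture can be worked out by hand. The sketch below is plain arithmetic and assumes the default sizes (`input_dim = hidden_dim = output_dim = 512`, `num_layers = 2`); the real input dimension is taken from the encoder output at run time, so the actual count may differ. The function names are illustrative.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def layernorm_params(dim: int) -> int:
    """Per-feature scale and shift of a LayerNorm with affine parameters."""
    return 2 * dim


def adapter_params(input_dim: int = 512, hidden_dim: int = 512,
                   output_dim: int = 512, num_layers: int = 2) -> int:
    """Count trainable parameters; GELU and Dropout contribute none."""
    total = linear_params(input_dim, hidden_dim) + layernorm_params(hidden_dim)
    total += (num_layers - 1) * (
        linear_params(hidden_dim, hidden_dim) + layernorm_params(hidden_dim)
    )
    total += linear_params(hidden_dim, output_dim)
    return total


n_params = adapter_params()  # 2 * (262,656 + 1,024) + 262,656 = 790,016
```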

Prepare training data.

In [None]:
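# Split data into train/validation sets
# (detach/cpu keep the NumPy conversion safe if a tensor tracks gradients or lives on an accelerator)
X = unimol_embeddings.detach().cpu().numpy()
y = valid_drug_embeddings.detach().cpu().numpy()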

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=conf.seed
)

_logger.info(f"Training samples: {len(X_train)}")
_logger.info(f"Validation samples: {len(X_val)}")

# Convert to tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32)
X_val_t = torch.tensor(X_val, dtype=torch.float32)
y_val_t = torch.tensor(y_val, dtype=torch.float32)

# Create data loaders
train_dataset = TensorDataset(X_train_t, y_train_t)
val_dataset = TensorDataset(X_val_t, y_val_t)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
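
The split above only works because inputs and targets stay paired row by row. The following sketch shows that idea with the standard library alone; it is not how `train_test_split` is implemented, and `split_indices` is a hypothetical helper that shuffles row indices once and applies the same partition to both arrays.

```python
import random


def split_indices(n_samples: int, val_fraction: float = 0.2, seed: int = 42):
    """Return (train_indices, val_indices) from a single seeded shuffle."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_val = round(n_samples * val_fraction)
    return indices[n_val:], indices[:n_val]


# Toy paired data: y is a known function of X so the pairing can be checked.
X_toy = [float(i) for i in range(10)]
y_toy = [10.0 * x for x in X_toy]

train_idx, val_idx = split_indices(len(X_toy))
X_train_toy = [X_toy[i] for i in train_idx]
y_train_toy = [y_toy[i] for i in train_idx]
```

Indexing both arrays with the same index lists is what keeps each structure embedding matched to its PROTON target.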

Training loop with early stopping.

In [None]:
def train_adapter(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 500,
    lr: float = 1e-3,
    patience: int = 20,
    device: torch.device = TORCH_DEVICE,
) -> dict:
    """Train the adapter with early stopping.

    Args:
        model: MolecularAdapter model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Maximum number of epochs
        lr: Learning rate
        patience: Early stopping patience
        device: Torch device

    Returns:
        Training history dict
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    criterion = nn.MSELoss()

    history = {"train_loss": [], "val_loss": [], "cosine_sim": []}
    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    pbar = tqdm(range(num_epochs), desc="Training adapter")
    for epoch in pbar:
        # Training
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(X_batch)

        train_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)

                y_pred = model(X_batch)
                loss = criterion(y_pred, y_batch)
                val_loss += loss.item() * len(X_batch)

                all_preds.append(y_pred.cpu())
                all_targets.append(y_batch.cpu())

        val_loss /= len(val_loader.dataset)

        # Compute cosine similarity
        all_preds = torch.cat(all_preds, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        cosine_sim = F.cosine_similarity(all_preds, all_targets, dim=1).mean().item()

        # Update history
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["cosine_sim"].append(cosine_sim)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Update progress bar
        pbar.set_postfix({
            "train": f"{train_loss:.4f}",
            "val": f"{val_loss:.4f}",
            "cos": f"{cosine_sim:.4f}",
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
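            # Clone the tensors: state_dict() returns references that later
            # optimizer steps update in place, so a shallow copy would not
            # preserve the best weights.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}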
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                _logger.info(f"Early stopping at epoch {epoch + 1}")
                break

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)

    return history
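
The early-stopping rule in `train_adapter` can be isolated from the training code. As a minimal sketch (the function `early_stopping_epochs` and the loss values are made up for illustration), the following replays a sequence of validation losses and reports which epoch held the best loss and at which epoch training would stop, using 1-based epochs as in the log message above.

```python
from typing import Sequence, Tuple


def early_stopping_epochs(val_losses: Sequence[float], patience: int) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch_run) for a patience-based stopping rule."""
    best_loss = float("inf")
    best_epoch = 0
    counter = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # Strict improvement resets the patience counter.
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)


# Hypothetical validation losses: best at epoch 5, then three epochs without improvement.
losses = [1.0, 0.8, 0.9, 0.85, 0.7, 0.71, 0.72, 0.73]
best_epoch, stop_epoch = early_stopping_epochs(losses, patience=3)
```

Because training continues for `patience` epochs past the best one, the weights at the stopping epoch are not the best weights, which is why the function restores a saved snapshot before returning.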

Train the adapter.

In [None]:
# Initialize adapter
adapter = MolecularAdapter(
    input_dim=unimol_embeddings.shape[1],
    hidden_dim=512,
    output_dim=EMBEDDING_DIM,
    num_layers=2,
    dropout=0.1,
)

_logger.info(f"Adapter parameters: {sum(p.numel() for p in adapter.parameters()):,}")

# Train
history = train_adapter(
    adapter,
    train_loader,
    val_loader,
    num_epochs=500,
    lr=1e-3,
    patience=20,
)

Save trained adapter.

In [None]:
# Save adapter
torch.save({
    "model_state_dict": adapter.state_dict(),
    "input_dim": unimol_embeddings.shape[1],
    "hidden_dim": 512,
    "output_dim": EMBEDDING_DIM,
    "num_layers": 2,
}, ADAPTER_PATH)
_logger.info(f"Saved adapter to {ADAPTER_PATH}")

Plot training curves.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
ax = axes[0]
ax.plot(history["train_loss"], label="Train", color="#1f77b4")
ax.plot(history["val_loss"], label="Validation", color="#ff7f0e")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.set_title("Training Loss")
ax.legend()
ax.grid(True, alpha=0.3)

# Cosine similarity
ax = axes[1]
ax.plot(history["cosine_sim"], color="#2ca02c")
ax.set_xlabel("Epoch")
ax.set_ylabel("Cosine Similarity")
ax.set_title("Validation Cosine Similarity")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "training_curves.pdf", bbox_inches="tight")
plt.show()

## Evaluation

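Evaluate the adapter's ability to reconstruct PROTON embeddings from molecular structure. The cells below score all valid drugs, including those used for training, so the resulting statistics are more optimistic than the held-out validation cosine similarity tracked during training.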

In [None]:
# Compute adapted embeddings for all valid drugs
adapter.eval()
adapter = adapter.to(TORCH_DEVICE)

with torch.no_grad():
    adapted_embeddings = adapter(unimol_embeddings.to(TORCH_DEVICE)).cpu()

_logger.info(f"Adapted embeddings shape: {adapted_embeddings.shape}")

Compute per-drug cosine similarity.

In [None]:
# Compute cosine similarity for each drug
cosine_similarities = F.cosine_similarity(
    adapted_embeddings, valid_drug_embeddings, dim=1
).numpy()

# Add to dataframe
valid_drugs_df["cosine_similarity"] = cosine_similarities

# Statistics
_logger.info(f"Mean cosine similarity: {cosine_similarities.mean():.4f}")
_logger.info(f"Std cosine similarity: {cosine_similarities.std():.4f}")
_logger.info(f"Min cosine similarity: {cosine_similarities.min():.4f}")
_logger.info(f"Max cosine similarity: {cosine_similarities.max():.4f}")
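
The adapter is trained with an MSE loss but evaluated with cosine similarity, and the two measure different things. The sketch below, written with the standard library only and using made-up vectors, shows a prediction that points in exactly the same direction as its target (cosine similarity of 1) while still having a nonzero MSE because its magnitude is off. Note that `F.cosine_similarity` additionally guards against zero norms with a small epsilon, which this sketch omits.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two nonzero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


target = [1.0, 2.0, 2.0]
half_scale = [0.5, 1.0, 1.0]  # same direction, half the magnitude

cos = cosine_similarity(half_scale, target)
err = mse(half_scale, target)
```

A high mean cosine similarity therefore says the adapter recovers embedding directions well; whether magnitudes also matter depends on how the decoder uses the embeddings.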

Plot cosine similarity distribution.

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))

ax.hist(cosine_similarities, bins=50, color="#1f77b4", edgecolor="black", alpha=0.7)
ax.axvline(cosine_similarities.mean(), color="red", linestyle="--", 
           label=f"Mean: {cosine_similarities.mean():.3f}")
ax.set_xlabel("Cosine Similarity")
ax.set_ylabel("Count")
ax.set_title("Distribution of Adapted vs. Original Embedding Similarity")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "cosine_similarity_distribution.pdf", bbox_inches="tight")
plt.show()

Identify best and worst reconstructed drugs.

In [None]:
# Best reconstructed drugs
_logger.info("Top 10 best reconstructed drugs:")
best_drugs = valid_drugs_df.nlargest(10, "cosine_similarity")[["node_name", "cosine_similarity"]]
display(best_drugs)

# Worst reconstructed drugs
_logger.info("Top 10 worst reconstructed drugs:")
worst_drugs = valid_drugs_df.nsmallest(10, "cosine_similarity")[["node_name", "cosine_similarity"]]
display(worst_drugs)

## De Novo Molecule Inference

Use the trained adapter to project new molecules into PROTON's embedding space and score them against diseases.

In [None]:
def load_adapter(checkpoint_path: Path, device: torch.device = TORCH_DEVICE) -> MolecularAdapter:
    """Load a trained adapter from checkpoint.

    Args:
        checkpoint_path: Path to adapter checkpoint
        device: Torch device

    Returns:
        Loaded MolecularAdapter model
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    adapter = MolecularAdapter(
        input_dim=checkpoint["input_dim"],
        hidden_dim=checkpoint["hidden_dim"],
        output_dim=checkpoint["output_dim"],
        num_layers=checkpoint["num_layers"],
    )
    adapter.load_state_dict(checkpoint["model_state_dict"])
    adapter = adapter.to(device)
    adapter.eval()
    return adapter


def project_denovo_molecule(
    smiles: str,
    adapter: MolecularAdapter,
    device: torch.device = TORCH_DEVICE,
) -> torch.Tensor:
    """Project a de novo molecule into PROTON's embedding space.

    Args:
        smiles: SMILES string of the molecule
        adapter: Trained MolecularAdapter
        device: Torch device

    Returns:
        Embedding tensor [1, 512]
    """
    # Import Uni-Mol2
    from unimol_tools import UniMolRepr

    # Encode with Uni-Mol2
    clf = UniMolRepr(data_type="molecule", remove_hs=False)
    reprs = clf.get_repr([smiles], return_atomic_reprs=False)
    mol_embedding = torch.tensor(reprs["cls_repr"], dtype=torch.float32).to(device)

    # Project through adapter
    with torch.no_grad():
        projected = adapter(mol_embedding)

    return projected


def score_molecule_disease(
    smiles: str,
    disease_index: int,
    adapter: MolecularAdapter,
    model: HGT,
    kg: object,
    embeddings: torch.Tensor,
    edge_type: tuple = ("disease", "indication", "drug"),
) -> float:
    """Score a de novo molecule against a disease.

    Args:
        smiles: SMILES string of the molecule
        disease_index: Node index of the disease
        adapter: Trained MolecularAdapter
        model: PROTON HGT model
        kg: Knowledge graph
        embeddings: Pre-computed PROTON embeddings
        edge_type: Edge type tuple for scoring

    Returns:
        Indication score (higher = more likely)
    """
    # Get molecule embedding
    mol_embedding = project_denovo_molecule(smiles, adapter)

    # Overwrite a placeholder slot (the first drug node) with the new molecule's embedding.
    # Note: `nodes` is the notebook-level DataFrame, not a function argument.
    drug_nodes = nodes[nodes["node_type"] == "drug"]
    placeholder_idx = drug_nodes["node_index"].iloc[0]

    # Clone embeddings and replace placeholder
    modified_embeddings = embeddings.clone()
    modified_embeddings[placeholder_idx] = mol_embedding.squeeze().cpu()

    # Get decoder weights
    decoder = model.decoder.W.detach().cpu()

    # Score
    scores = model.get_scores_from_embeddings(
        src_ids=[disease_index],
        dst_ids=[placeholder_idx],
        query_edge_type=edge_type,
        embeddings=modified_embeddings,
        decoder=decoder,
        query_kg=kg,
        use_cache=False,
    )

    return scores[0].item()
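
The placeholder trick used by `score_molecule_disease` can be illustrated without the model: overwrite one row in a copy of the embedding table, score that row, and leave the original table untouched. The sketch below uses plain Python lists and a hypothetical dot-product scorer (`toy_score`, an assumption made only for illustration); it does not reproduce PROTON's decoder.

```python
import copy


def toy_score(src_vec, dst_vec):
    """Hypothetical stand-in for the decoder: a plain dot product."""
    return sum(s * d for s, d in zip(src_vec, dst_vec))


def score_with_placeholder(embeddings, src_idx, placeholder_idx, new_vec):
    """Score `new_vec` against row `src_idx` by occupying a placeholder row in a copy."""
    modified = copy.deepcopy(embeddings)  # plays the role of embeddings.clone()
    modified[placeholder_idx] = list(new_vec)
    return toy_score(modified[src_idx], modified[placeholder_idx])


# Toy table: row 0 plays the disease, row 1 plays the placeholder drug.
table = [[1.0, 0.0, 2.0], [0.5, 0.5, 0.5], [3.0, 1.0, 0.0]]
novel = [2.0, 1.0, 1.0]

swapped_score = score_with_placeholder(table, src_idx=0, placeholder_idx=1, new_vec=novel)
direct_score = toy_score(table[0], novel)
```

In this toy setting the swapped and direct scores agree, and the original row is preserved. That equivalence holds for any scorer that depends only on the two embedding rows; whether the real decoder draws on additional context from `query_kg` is not shown here.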

### Example: Score Novel Molecules Against Parkinson's Disease

In [None]:
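# Find the Parkinson's disease node index (restricted to disease nodes so a
# same-named node of another type cannot be picked up)
pd_node = nodes[(nodes["node_name"] == "Parkinson disease") & (nodes["node_type"] == "disease")]
pd_index = pd_node["node_index"].values[0]
_logger.info(f"Parkinson's disease index: {pd_index}")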

In [None]:
# Sanity check: score Levodopa, a known PD drug
levodopa_smiles = "NC(Cc1ccc(O)c(O)c1)C(=O)O"

# Score against Parkinson's disease
score = score_molecule_disease(
    smiles=levodopa_smiles,
    disease_index=pd_index,
    adapter=adapter,
    model=pretrain_model,
    kg=kg,
    embeddings=embeddings,
)

_logger.info(f"Levodopa indication score for Parkinson's disease: {score:.4f}")

Compare with existing drug scores for reference.

In [None]:
# Get decoder
decoder = pretrain_model.decoder.W.detach().cpu()

# Score all existing drugs against Parkinson's disease
all_drug_indices = drug_nodes["node_index"].values
pd_indices = [pd_index] * len(all_drug_indices)

all_scores = pretrain_model.get_scores_from_embeddings(
    src_ids=pd_indices,
    dst_ids=all_drug_indices.tolist(),
    query_edge_type=("disease", "indication", "drug"),
    embeddings=embeddings,
    decoder=decoder,
    query_kg=kg,
    use_cache=False,
)

# Statistics
all_scores_np = all_scores.cpu().numpy()
_logger.info(f"Existing drug scores - Mean: {all_scores_np.mean():.4f}, Std: {all_scores_np.std():.4f}")
_logger.info(f"Existing drug scores - Min: {all_scores_np.min():.4f}, Max: {all_scores_np.max():.4f}")
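
A raw score is easier to interpret once it is placed within the distribution of existing-drug scores. The sketch below computes a z-score and a percentile using only the standard library; `contextualize_score` is a hypothetical helper and the numbers are made up for illustration. In the notebook, the inputs would be `score` and `all_scores_np.tolist()`.

```python
from bisect import bisect_right
from statistics import mean, pstdev


def contextualize_score(score, reference_scores):
    """Return (z-score, percentile) of `score` relative to `reference_scores`."""
    ref = sorted(reference_scores)
    mu = mean(ref)
    sigma = pstdev(ref)  # population std, matching NumPy's default .std()
    z = (score - mu) / sigma if sigma > 0 else 0.0
    # Percentage of reference scores that are less than or equal to `score`.
    percentile = 100.0 * bisect_right(ref, score) / len(ref)
    return z, percentile


# Illustrative values only, not results from PROTON.
reference = [0.10, 0.20, 0.30, 0.40, 0.50]
z, pct = contextualize_score(0.45, reference)
```

A high percentile for a known PD drug such as Levodopa would suggest the adapter lands it in a sensible region of the embedding space, although a single molecule is weak evidence either way.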

### Batch Scoring for Multiple Molecules

In [None]:
def score_molecules_batch(
    smiles_list: list[str],
    disease_index: int,
    adapter: MolecularAdapter,
    model: HGT,
    kg: object,
    embeddings: torch.Tensor,
    device: torch.device = TORCH_DEVICE,
) -> pd.DataFrame:
    """Score multiple molecules against a disease.

    Args:
        smiles_list: List of SMILES strings
        disease_index: Node index of the disease
        adapter: Trained MolecularAdapter
        model: PROTON HGT model
        kg: Knowledge graph
        embeddings: Pre-computed PROTON embeddings
        device: Torch device

    Returns:
        DataFrame with SMILES and scores
    """
    # Encode all molecules with Uni-Mol2
    from unimol_tools import UniMolRepr

    clf = UniMolRepr(data_type="molecule", remove_hs=False)
    reprs = clf.get_repr(smiles_list, return_atomic_reprs=False)
    mol_embeddings = torch.tensor(reprs["cls_repr"], dtype=torch.float32).to(device)

    # Project through adapter
    with torch.no_grad():
        projected = adapter(mol_embeddings).cpu()

    # Score each molecule by writing it into a placeholder drug slot.
    # NOTE: `nodes` is the notebook-level node table, not a function argument.
    scores = []
    drug_nodes_local = nodes[nodes["node_type"] == "drug"]
    placeholder_idx = drug_nodes_local["node_index"].iloc[0]
    decoder = model.decoder.W.detach().cpu()

    # Clone once; each iteration overwrites the same placeholder row.
    modified_embeddings = embeddings.clone()
    for emb in projected:
        modified_embeddings[placeholder_idx] = emb

        score = model.get_scores_from_embeddings(
            src_ids=[disease_index],
            dst_ids=[placeholder_idx],
            query_edge_type=("disease", "indication", "drug"),
            embeddings=modified_embeddings,
            decoder=decoder,
            query_kg=kg,
            use_cache=False,
        )
        scores.append(score[0].item())

    return pd.DataFrame({"SMILES": smiles_list, "score": scores})

Example: Score a set of candidate molecules.

In [None]:
# Example candidate molecules: one known PD drug (Levodopa) and three
# familiar compounds for comparison
candidate_smiles = [
    "NC(Cc1ccc(O)c(O)c1)C(=O)O",  # Levodopa
    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # Ibuprofen (not a PD drug)
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
    "CC(=O)OC1=CC=CC=C1C(=O)O",  # Aspirin
]

candidate_names = ["Levodopa", "Ibuprofen", "Caffeine", "Aspirin"]

# Score candidates
results = score_molecules_batch(
    smiles_list=candidate_smiles,
    disease_index=pd_index,
    adapter=adapter,
    model=pretrain_model,
    kg=kg,
    embeddings=embeddings,
)

results["name"] = candidate_names
results = results[["name", "SMILES", "score"]].sort_values("score", ascending=False)

_logger.info("Candidate molecule scores for Parkinson's disease:")
display(results)
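
For candidate lists much longer than this example, it may help to process the SMILES in fixed-size chunks so that a single encoding call does not have to hold every molecule at once. The generator below is a minimal sketch; `iter_chunks` and the chunk size are hypothetical and not part of the notebook.

```python
from typing import Iterator


def iter_chunks(items: list[str], chunk_size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of `items` with at most `chunk_size` elements."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


# Toy SMILES library (simple alkanes) used only to show the chunking.
smiles_library = ["C", "CC", "CCC", "CCCC", "CCCCC"]
chunks = list(iter_chunks(smiles_library, chunk_size=2))
# In the notebook, each chunk could be passed to score_molecules_batch and the
# resulting DataFrames combined with pd.concat(..., ignore_index=True).
```

Since `score_molecules_batch` constructs the Uni-Mol2 encoder on each call, larger chunks amortize that setup cost better; the right size depends on available memory.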

## Summary

This notebook demonstrated how to:

1. **Collect SMILES strings** from PubChem for the existing drugs in NeuroKG
2. **Encode molecules** using Uni-Mol2's pre-trained molecular encoder
3. **Train an adapter** to map Uni-Mol2 representations to PROTON's embedding space
4. **Project de novo molecules** into PROTON's space and score them against diseases

### Key Outputs

| File | Description |
|------|-------------|
| `data/neurokg/drugs/drug_smiles.csv` | Cached SMILES for NeuroKG drugs |
| `data/neurokg/drugs/unimol_embeddings.pt` | Cached Uni-Mol2 embeddings |
| `data/checkpoints/molecular_adapter.pt` | Trained adapter weights |

### Usage

To score a new molecule:

```python
# Load adapter
adapter = load_adapter(ADAPTER_PATH)

# Score molecule against disease
score = score_molecule_disease(
    smiles="YOUR_SMILES_HERE",
    disease_index=disease_index,
    adapter=adapter,
    model=pretrain_model,
    kg=kg,
    embeddings=embeddings,
)
```