In [None]:
# ! pip install orb-models
# ! pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann#egg=af434039ae14bedcbb838a7808924d6689274168"
# ! pip install rdkit==2023.9.4
# ! pip install xyz2graph
# ! pip install nglview
# ! pip install packaging==24.1
# ! pip install transformers==4.46.3

In [None]:
# ! pip install --upgrade torch_geometric
# ! pip install ase==3.24.0
# ! pip install torch_nl==0.3
# ! pip install rdkit
# ! pip install PyTDC

In [None]:
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import os, sys, gc
import pandas as pd
from rdkit.Chem import AllChem
import numpy as np
from rdkit import Chem
import ase
from ase import Atoms

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import orb_models.utils as utils
from orb_models.forcefield.base import AtomGraphs
from orb_models.dataset.augmentations import rotate_randomly
from orb_models.dataset.base_datasets import AtomsDataset
sys.path.append('../external_repos/')
from orb_models_modified.orb_models.forcefield import base, pretrained, atomic_system, property_definitions

sys.path.append('../')
import utils as tdc_utils

from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool

from tdc.single_pred import ADME, Tox

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Notebook-wide dtype constant (the cells visible here pass torch.float32 literally
# rather than referencing this name).
dtype  = torch.float32
# NOTE(review): this shadows the builtin `compile()` for the rest of the kernel session
# and is not referenced by any visible cell — confirm it is still needed.
compile = None

# Trade a little float32 precision for speed: 'high' lets PyTorch use faster
# reduced-precision float32 matmul kernels (see the PyTorch docs for the exact modes),
# and the two flags below explicitly allow TF32 in CUDA matmuls and in cuDNN.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32  = True

In [None]:
import hashlib

def _embed_conformer(mol, random_seed: int, jitter_amp: float) -> bool:
    """Give ``mol`` a conformer, trying three strategies from best to crudest.

    Args:
        mol: RDKit molecule (hydrogens already added); modified in place.
        random_seed: Seed for the first embedding attempt and for the jitter.
        jitter_amp: Width of the uniform jitter (each coordinate is perturbed by a
            value in ``[-jitter_amp / 2, jitter_amp / 2)``).

    Returns:
        True as soon as one strategy reports success, False if all of them fail.
        A failing strategy only prints its exception; it never raises.
    """
    # Strategy 1: random-coordinate embedding -> UFF -> small jitter -> MMFF.
    try:
        status = AllChem.EmbedMolecule(
            mol, maxAttempts=2000, randomSeed=random_seed, useRandomCoords=True
        )
        if status == 0:
            force_field = AllChem.UFFGetMoleculeForceField(mol)
            if force_field is not None:
                force_field.Initialize()
                force_field.Minimize(maxIts=1000)

                # Jitter the minimized geometry to break exact symmetry.
                # A private RandomState yields the same numbers that
                # np.random.seed(random_seed) + np.random.rand(...) would, but
                # does not reset the process-wide NumPy RNG as a side effect.
                conf = mol.GetConformer()
                pos = conf.GetPositions()
                rng = np.random.RandomState(random_seed)
                noise = (rng.rand(*pos.shape) - 0.5) * jitter_amp
                for i, p in enumerate(pos + noise):
                    conf.SetAtomPosition(i, p.tolist())

                # Final relaxation of the jittered geometry.
                AllChem.MMFFOptimizeMolecule(mol)
                return True
    except Exception as e:
        print(f"Method 1 failed: {e}")

    # Strategy 2: plain ETKDGv3 embedding, no force-field refinement.
    try:
        if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == 0:
            return True
    except Exception as e:
        print(f"Method 2 failed: {e}")

    # Strategy 3: last resort, a 2D depiction used as coordinates.
    # clearConfs=True (the RDKit default, spelled out here) discards any
    # conformer left over from a failed attempt, so no placeholder conformer
    # needs to be created first.
    try:
        AllChem.Compute2DCoords(mol, clearConfs=True)
        if mol.GetNumConformers() > 0:
            return True
    except Exception as e:
        print(f"Method 3 failed: {e}")

    return False


def _has_overlapping_atoms(pos, eps: float = 1e-6) -> bool:
    """Return True if any two distinct rows of ``pos`` are within ``eps`` of each other."""
    coords = torch.as_tensor(pos)
    dist = torch.cdist(coords, coords)
    # Ignore the trivial zero distance of every atom to itself.
    diag = torch.arange(dist.shape[0])
    dist[diag, diag] = float('inf')
    return bool((dist <= eps).any().item())


def smiles_to_xyz_file(smiles: str, output_dir: str = 'xyz_files/', random_seed: int = 42,
                      jitter_amp: float = 1e-2) -> Optional[str]:
    """Generate a conformer for ``smiles`` and write its heavy atoms to an XYZ file.

    The file is named after the first 10 hex digits of the MD5 of ``smiles`` and
    stores the SMILES string verbatim on the XYZ comment line (line 2).
    Hydrogens are added for conformer generation and for the overlap check, but
    only non-hydrogen atoms are written out.

    Args:
        smiles: Input SMILES string.
        output_dir: Directory for the file; created if missing (even when the
            function subsequently returns None).
        random_seed: Seed for embedding and jitter.
        jitter_amp: Width of the uniform coordinate jitter applied before the
            final MMFF optimization.

    Returns:
        Path of the written ``.xyz`` file, or None when the SMILES cannot be
        parsed, no conformer can be generated, two atoms coincide, or the
        molecule has no heavy atoms.
    """
    os.makedirs(output_dir, exist_ok=True)

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)  # explicit hydrogens give better geometries

    if not _embed_conformer(mol, random_seed, jitter_amp) or mol.GetNumConformers() == 0:
        return None

    pos = mol.GetConformer().GetPositions()

    # Coincident atoms make the structure unusable; let the caller drop it.
    if _has_overlapping_atoms(pos, eps=1e-6):
        return None

    # Every atom of the conformer has a row in `pos`, so no atom can be
    # silently skipped here (which would produce a truncated structure).
    heavy_atoms = [
        (atom.GetSymbol(), pos[atom.GetIdx()])
        for atom in mol.GetAtoms()
        if atom.GetSymbol() != 'H'
    ]
    if not heavy_atoms:
        return None

    digest = hashlib.md5(smiles.encode()).hexdigest()[:10]
    path = os.path.join(output_dir, f"{digest}.xyz")

    with open(path, 'w') as f:
        f.write(f"{len(heavy_atoms)}\n")
        f.write(f"{smiles}\n")
        for sym, (x, y, z) in heavy_atoms:
            f.write(f"{sym} {x:.6f} {y:.6f} {z:.6f}\n")

    return path

In [None]:
def canonicalize_smiles(smiles):
    """Return the RDKit canonical form of ``smiles``, or None if it cannot be produced.

    None is returned both when RDKit cannot parse the input (``MolFromSmiles``
    yields None) and when RDKit raises an ordinary exception (for example on a
    non-string input). Interpreter-level signals such as ``KeyboardInterrupt``
    and ``SystemExit`` are deliberately NOT swallowed, so a long batch job can
    still be interrupted.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None

In [None]:
def get_smiles(input_file):
    """Return the comment line (second line) of an XYZ file, whitespace-stripped.

    The string is returned exactly as stored; it is not canonicalized.
    Raises IndexError if the file has fewer than two lines.
    """
    with open(input_file, 'r') as handle:
        second_line = handle.read().splitlines()[1]
    return second_line.strip()

In [None]:
class XYZFolderDataset(AtomsDataset):
    """Dataset of (graph, target) pairs backed by a folder of single-molecule .xyz files.

    Each row of ``df`` is matched to a file through its ``'smiles'`` column: the
    SMILES must equal the comment line (line 2) of some ``*.xyz`` file in
    ``xyz_dir``, as returned by ``get_smiles``. Rows without a matching file are
    dropped at construction time (a warning with the count is printed), so
    ``len(dataset)`` can be smaller than ``len(df)``; ``self.df`` is the filtered
    frame, re-indexed 0..len-1, in the same order in which items are served.

    Relies on the ``AtomsDataset`` base class for ``self.name``,
    ``self.system_config`` and ``self.augmentations`` (presumably stored by
    ``super().__init__`` — confirm against the base class).
    """

    def __init__(
        self,
        df,                        # pandas DataFrame with a 'smiles' column and `target_col`
        xyz_dir: Union[str, Path],
        system_config: atomic_system.SystemConfig,
        target_col: str,
        target_config: Optional[property_definitions.PropertyConfig] = None,
        augmentations: Optional[List[Callable[[ase.Atoms], None]]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(name="xyz_folder", system_config=system_config, augmentations=augmentations)
        self.xyz_dir = Path(xyz_dir)
        self.target_col = target_col
        self.dtype = dtype or torch.get_default_dtype()
        self.target_config = target_config or property_definitions.PropertyConfig()
        # Callables invoked as c(atoms, {}, self.name) on every item; empty by default.
        self.constraints: List[Callable] = []

        # SMILES -> file path. If two files carry the same SMILES, the one
        # yielded last by glob() wins.
        self._lookup = {get_smiles(str(p)): p for p in self.xyz_dir.glob("*.xyz")}

        # Keep only rows that actually have a structure on disk.
        df = df.reset_index(drop=True)
        has_xyz = df['smiles'].isin(set(self._lookup))
        num_bad = int((~has_xyz).sum())
        if num_bad > 0:
            print(f"Warning: dropping {num_bad} rows with no matching .xyz file")
        self.df = df.loc[has_xyz].reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.df)

    def _read_atoms(self, smiles: str) -> ase.Atoms:
        """Load the ASE Atoms stored for ``smiles``; raise KeyError if there is no file."""
        # Import the submodule explicitly: a bare `import ase` does not
        # guarantee that `ase.io` is loaded as an attribute of the package.
        from ase.io import read as read_atoms

        xyz_path = self._lookup.get(smiles)
        if xyz_path is None:
            raise KeyError(f"Could not find .xyz for SMILES `{smiles}` in {self.xyz_dir}")
        return read_atoms(str(xyz_path))

    def __getitem__(self, idx: int) -> Tuple[AtomGraphs, torch.Tensor]:
        """Return the graph for row ``idx`` and its target as a 0-d tensor of ``self.dtype``."""
        row = self.df.iloc[idx]
        target = row[self.target_col]

        atoms = self._read_atoms(row['smiles'])

        # Replace whatever metadata the reader attached with the configured targets.
        atoms.info = {}
        atoms.info.update(self.target_config.extract(row.to_dict(), self.name, "targets"))

        # Augmentations mutate `atoms` in place (signature Callable[[Atoms], None]).
        for aug in self.augmentations or []:
            aug(atoms)

        # Clear any constraints from the file, then apply the dataset's own.
        atoms.set_constraint(None)
        for c in self.constraints:
            c(atoms, {}, self.name)

        graph = atomic_system.ase_atoms_to_atom_graphs(
            atoms=atoms,
            system_config=self.system_config,
            edge_method="knn_scipy",
            wrap=True,
            system_id=idx,
            output_dtype=self.dtype,
            graph_construction_dtype=self.dtype,
        )

        return graph, torch.tensor(target, dtype=self.dtype)

    def get_atom(self, idx: int) -> ase.Atoms:
        """Return the raw ASE Atoms for this index (no augmentation, no targets)."""
        return self._read_atoms(self.df.iloc[idx]['smiles'])

    def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]:
        """Return both the raw ASE Atoms and the row's metadata as a dict."""
        row = self.df.iloc[idx]
        return self._read_atoms(row['smiles']), row.to_dict()

In [None]:
def collate_graph_and_targets(batch):
    """Collate ``(AtomGraphs, target_tensor)`` samples into one batched pair.

    Graphs are merged with ``base.batch_graphs``; targets are stacked along a
    new leading dimension.
    """
    graph_seq, target_seq = zip(*batch)
    return base.batch_graphs(graph_seq), torch.stack(target_seq, dim=0)

In [None]:
def _build_xyz_loader(
    df,
    xyz_dir,
    system_config,
    target_col,
    batch_size,
    num_workers,
    target_config,
    augmentations,
    dtype,
):
    """Shared implementation behind make_train_loader / make_val_loader."""
    dataset = XYZFolderDataset(
        df=df,
        xyz_dir=xyz_dir,
        system_config=system_config,
        target_col=target_col,
        target_config=target_config,
        augmentations=augmentations,
        dtype=dtype,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        # Order must stay deterministic: the extraction loop lines embeddings up
        # with the dataset rows purely by position.
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=utils.worker_init_fn,
        collate_fn=collate_graph_and_targets,
        # Pinned host memory only helps host->GPU copies; requesting it on a
        # CPU-only machine just produces a warning.
        pin_memory=torch.cuda.is_available(),
    )


def make_train_loader(
    df,
    xyz_dir: str,
    system_config,
    target_col: str,
    batch_size: int,
    num_workers: int,
    target_config=None,
    augmentations=None,
    dtype=None,
):
    """Build an unshuffled DataLoader over the training molecules.

    Args:
        df: DataFrame with a ``'smiles'`` column and ``target_col``.
        xyz_dir: Folder containing the ``*.xyz`` structures for ``df``.
        system_config: Graph-construction config handed to ``XYZFolderDataset``.
        target_col: Name of the column holding the scalar target.
        batch_size: Number of molecules per batch.
        num_workers: DataLoader worker processes.
        target_config: Optional property config forwarded to the dataset.
        augmentations: Optional in-place ``ase.Atoms`` augmentations.
        dtype: Tensor dtype for graphs and targets (dataset default if None).

    Returns:
        A ``DataLoader`` yielding ``(batched_graph, batched_targets)`` in the
        row order of ``loader.dataset.df``. Rows of ``df`` without a structure
        on disk are not served.
    """
    return _build_xyz_loader(
        df, xyz_dir, system_config, target_col, batch_size, num_workers,
        target_config, augmentations, dtype,
    )


def make_val_loader(
    df,
    xyz_dir: str,
    system_config,
    target_col: str,
    batch_size: int,
    num_workers: int,
    target_config=None,
    augmentations=None,
    dtype=None,
):
    """Build an unshuffled DataLoader over validation/test molecules.

    Takes the same arguments and returns the same kind of loader as
    ``make_train_loader``; the two entry points are kept separate so callers
    can later give training its own sampling policy without touching
    evaluation code.
    """
    return _build_xyz_loader(
        df, xyz_dir, system_config, target_col, batch_size, num_workers,
        target_config, augmentations, dtype,
    )

In [None]:
class Model(nn.Module):
    """Wraps a pretrained Orb backbone and mean-pools one of its intermediate layers.

    ``reg_drop_rate``, ``reg_size`` and ``num_labels`` are only recorded as
    attributes (``num_labels`` under the name ``num_targets``); no regression
    head is built from them in this class.
    """

    def __init__(self, reg_drop_rate=0.1, reg_size=256, num_labels=1):
        super().__init__()
        self.reg_drop_rate = reg_drop_rate
        self.reg_size = reg_size
        self.num_targets = num_labels

        # Backbone is created on the CPU; callers move the whole module afterwards.
        # Alternative backbone tried previously:
        #   pretrained.orb_v3_direct_inf_mpa(device='cpu', precision="float32-highest")
        self.gnn = pretrained.orb_v3_conservative_inf_omat(
            device='cpu',
            precision="float32-highest",  # other documented options: "float64"
        )

    def forward(self, batched_graph, layer_idx):
        """Return one embedding per graph: the mean of layer ``layer_idx`` node features."""
        backbone_out = self.gnn(batched_graph)
        per_node = backbone_out["intermediate_layers"][layer_idx]
        # Graph id of every node, as int64 for the pooling op.
        graph_of_node = batched_graph._get_per_node_graph_indices().long()
        return global_mean_pool(per_node, graph_of_node)

In [None]:
# Use the device chosen in the config cell rather than a hard-coded 'cuda',
# so the notebook also runs on machines without a GPU.
model = Model().to(device)
model.eval();

In [None]:
# skip = ['solubility_aqsoldb', 'ppbr_az',
#        'herg', 'cyp2d6_substrate_carbonmangels',
#        'bbb_martins']

In [None]:
# NOTE(review): anomaly detection is a debugging aid and slows autograd down;
# confirm it is still wanted for a pure feature-extraction run.
torch.autograd.set_detect_anomaly(True)

base_dir = 'tdcommons/admet_group/'
dfs = []

# Intermediate backbone layers to export (one CSV per layer and split).
num_layers = 5
layer_indices = list(range(0, num_layers))

for task in os.listdir(base_dir):
    if task.startswith('.'):
        continue

    print(task)

    task_dir = os.path.join(base_dir, task)

    # if task in skip: continue

    # Resolve the evaluation metric; this also validates that the task is known.
    prefix = 'tdcommons/'
    if prefix+task in tdc_utils.tdc_mae_tasks:
        metric = 'mae'
    elif prefix+task in tdc_utils.tdc_spearman_task:
        metric = 'spearman'
    elif prefix+task in tdc_utils.polaris_pearson_tasks:
        metric = 'pearson'
    elif prefix+task in tdc_utils.tdc_auroc_tasks:
        metric = 'auc'
    elif prefix+task in tdc_utils.tdc_aucpr_tasks:
        metric = 'aucpr'
    elif prefix+task in tdc_utils.tdc_aucpr2_tasks:
        metric = 'aucpr'
    elif prefix+task in tdc_utils.polaris_aucpr_tasks:
        metric = 'aucpr'
    else:
        raise ValueError(f"Task {task} not found in any known task list.")

    # A task lives in either the ADME or the Tox group; try ADME first.
    # Catch Exception (not a bare except) so Ctrl-C still stops the run.
    try:
        data = ADME(name = task)
    except Exception:
        data = Tox(name = task)

    split = data.get_split(method = 'scaffold')

    train_df = split['train'].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)
    val_df = split['valid'].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)
    test_df = split['test'].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)

    train_dataloader = make_train_loader(
        df=train_df,
        xyz_dir=f'xyz_files/{task}/train/graphs',
        system_config=model.gnn.system_config,
        target_col='target',
        batch_size=64,
        num_workers=4,
        target_config=None,
        augmentations=[rotate_randomly],
        dtype=torch.float32,
    )

    test_dataloader = make_val_loader(
        df=test_df,
        xyz_dir=f'xyz_files/{task}/test/graphs',
        system_config=model.gnn.system_config,
        target_col='target',
        batch_size=256,
        num_workers=1,
        target_config=None,
        augmentations=None,
        dtype=torch.float32,
    )

    for split_name, loader in (('train', train_dataloader), ('test', test_dataloader)):
        # The dataset silently drops rows that have no .xyz file, so embeddings
        # must be joined with the dataset's *filtered* frame (index 0..len-1, in
        # loader order) and not with the raw split, which would attach
        # embeddings to the wrong molecules and pad the tail with NaN.
        kept_df = loader.dataset.df
        if len(kept_df) == 0:
            print(f"Warning: no structures found for {task}/{split_name}; skipping")
            continue

        # One pass over the data: the backbone returns every intermediate layer
        # in a single forward call, so pool all requested layers from it
        # instead of rebuilding graphs and re-running the network per layer.
        # This mirrors Model.forward (layer lookup + per-graph mean pooling).
        # NOTE(review): gradients are intentionally left enabled — presumably
        # the conservative backbone differentiates internally; confirm before
        # wrapping this in torch.no_grad().
        per_layer_embeddings = {layer_idx: [] for layer_idx in layer_indices}
        for mol_graphs, _targets in loader:
            mol_graphs = mol_graphs.to(device)
            layers = model.gnn(mol_graphs)["intermediate_layers"]
            batch_index = mol_graphs._get_per_node_graph_indices().long()
            for layer_idx in layer_indices:
                pooled = global_mean_pool(layers[layer_idx], batch_index)
                per_layer_embeddings[layer_idx].append(pooled.detach().cpu().numpy())

        for layer_idx in layer_indices:
            stacked = np.vstack(per_layer_embeddings[layer_idx])
            n_samples, emb_dim = stacked.shape
            if n_samples != len(kept_df):
                raise RuntimeError(
                    f"{task}/{split_name}: got {n_samples} embeddings for {len(kept_df)} molecules"
                )

            col_names = [f"orb_conserv_layer{layer_idx}_{i}" for i in range(emb_dim)]
            out_df = pd.DataFrame(stacked, columns=col_names)
            # Both frames carry a default 0..n-1 index, so this is a positional join.
            combined = pd.concat([kept_df, out_df], axis=1)

            n = f"{split_name}_orbconserv_layer{layer_idx}_emb.csv"
            combined.to_csv(os.path.join(task_dir, n), index=False)