In [None]:
# ! pip install --upgrade torch_geometric
# ! pip install ase==3.24.0
# ! pip install torch_nl==0.3
# ! pip install rdkit
# ! pip install PyTDC

In [None]:
import os
import pandas as pd
from rdkit.Chem import AllChem
import numpy as np
from rdkit import Chem
from sklearn.preprocessing import StandardScaler

import sys
sys.path.append('../external_repos/')
from posegnn.model import PosEGNN

from ase import Atoms as ASEAtoms

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool

from tdc.single_pred import ADME, Tox

sys.path.append('../')
import utils

# Fall back to CPU so the notebook still runs (e.g. checkpoint loading with
# map_location=device) on machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def _embed_uff_mmff(mol, random_seed: int, jitter_amp: float) -> bool:
    """Embed from random coordinates, relax with UFF, jitter, then relax with MMFF.

    Mutates ``mol`` in place. Returns True only if embedding succeeded and a UFF
    force field could be built; any exception is printed and reported as failure.
    """
    try:
        embed_result = AllChem.EmbedMolecule(mol, maxAttempts=2000, randomSeed=random_seed,
                                             useRandomCoords=True)
        if embed_result != 0:
            return False
        force_field = AllChem.UFFGetMoleculeForceField(mol)
        if force_field is None:
            return False
        force_field.Initialize()
        force_field.Minimize(maxIts=1000)

        # Small uniform jitter breaks exact symmetries before the final MMFF pass.
        # A private RandomState draws exactly the values np.random.seed(...) followed
        # by np.random.rand(...) would, without clobbering the global NumPy RNG.
        conf = mol.GetConformer()
        pos = conf.GetPositions()
        rng = np.random.RandomState(random_seed)
        noise = (rng.rand(*pos.shape) - 0.5) * jitter_amp
        for i, p in enumerate(pos + noise):
            conf.SetAtomPosition(i, p.tolist())

        # The return status of MMFFOptimizeMolecule is not checked.
        AllChem.MMFFOptimizeMolecule(mol)
        return True
    except Exception as e:
        print(f"Method 1 failed: {e}")
        return False


def _embed_etkdg(mol, random_seed: int) -> bool:
    """Embed with ETKDGv3, seeded so the fallback is reproducible. Mutates ``mol``."""
    try:
        params = AllChem.ETKDGv3()
        params.randomSeed = random_seed
        return AllChem.EmbedMolecule(mol, params) == 0
    except Exception as e:
        print(f"Method 2 failed: {e}")
        return False


def _embed_2d(mol) -> bool:
    """Last resort: attach 2D depiction coordinates as the conformer. Mutates ``mol``."""
    try:
        # Compute2DCoords replaces existing conformers by default (clearConfs=True),
        # so no placeholder conformer needs to be added first.
        AllChem.Compute2DCoords(mol)
        return mol.GetNumConformers() > 0
    except Exception as e:
        print(f"Method 3 failed: {e}")
        return False


def _has_overlaps(pos: np.ndarray, eps: float = 1e-6) -> bool:
    """Return True if any two distinct rows of ``pos`` are within ``eps`` of each other.

    Uses exact coordinate differences: identical points give a distance of exactly 0,
    which a matmul-based distance routine does not guarantee at this tolerance.
    """
    diff = pos[:, None, :] - pos[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # ignore each atom's distance to itself
    return bool((dist <= eps).any())


def smiles_to_ase_atoms(smiles: str, random_seed: int = 42,
                        jitter_amp: float = 1e-2) -> ASEAtoms | None:
    """Build a heavy-atom ASE ``Atoms`` object from a SMILES string.

    Conformer generation is attempted in order, stopping at the first success:
      1. random-coordinate embedding + UFF relax + jitter + MMFF relax,
      2. ETKDGv3 embedding (seeded with ``random_seed``),
      3. 2D depiction coordinates (planar fallback).
    Hydrogens are added for embedding and removed before the result is built.

    Args:
        smiles: input SMILES string.
        random_seed: seed for the embedding steps and for the jitter noise.
        jitter_amp: full width of the uniform noise added after UFF (method 1);
            each coordinate moves by a value in [-jitter_amp/2, jitter_amp/2).

    Returns:
        ``ASEAtoms`` with ``pbc=False``, or ``None`` if the SMILES does not parse,
        no conformer could be generated, two atoms (hydrogens included) coincide,
        or the molecule has no non-hydrogen atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    mol = Chem.AddHs(mol)

    # Short-circuit: later methods only run if the earlier ones failed.
    conformer_generated = (
        _embed_uff_mmff(mol, random_seed, jitter_amp)
        or _embed_etkdg(mol, random_seed)
        or _embed_2d(mol)
    )
    if not conformer_generated or mol.GetNumConformers() == 0:
        return None

    # Coincident atoms would break downstream distance-based models; drop the molecule.
    if _has_overlaps(mol.GetConformer().GetPositions(), eps=1e-6):
        return None

    if all(atom.GetSymbol() == 'H' for atom in mol.GetAtoms()):
        return None

    mol = Chem.RemoveHs(mol)
    # NOTE(review): with the 2D fallback the coordinates are presumably planar (z = 0);
    # callers that need true 3D geometry may want to reject those.
    pos = mol.GetConformer().GetPositions()

    return ASEAtoms(
        symbols=[a.GetSymbol() for a in mol.GetAtoms()],
        positions=pos,
        pbc=False
    )

In [None]:
def build_data_from_ase(atoms: ASEAtoms) -> Data:
    """Convert an ASE ``Atoms`` object into a single-graph PyG ``Data`` record.

    Stored fields:
      z          -- atomic numbers (int64), one per atom
      pos        -- Cartesian positions (float32), one row per atom
      box        -- the ASE cell as float32 with a leading singleton dimension
      batch      -- all-zero graph index (every atom belongs to graph 0)
      num_graphs -- 1
    """
    atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long)
    n_atoms = len(atomic_numbers)

    cell = torch.tensor(atoms.get_cell().tolist()).float()[None, ...]
    coords = torch.tensor(atoms.get_positions().tolist()).float()
    graph_index = torch.zeros(n_atoms, dtype=torch.long)

    return Data(
        z=atomic_numbers,
        pos=coords,
        box=cell,
        batch=graph_index,
        num_graphs=1,
    )

In [None]:
class Model(nn.Module):
    """Frozen PosEGNN backbone that mean-pools node embeddings into one vector per graph.

    Args:
        reg_drop_rate: stored as an attribute; not used inside this class.
        intermediate_size: stored as an attribute; not used inside this class.
        num_targets: stored as an attribute; not used inside this class.
        checkpoint_path: file loaded with ``torch.load``; must hold a dict with
            ``"config"`` (passed to ``PosEGNN``) and ``"state_dict"`` keys.
    """

    def __init__(self,
                 reg_drop_rate=0.1,
                 intermediate_size=256,
                 num_targets=1,
                 checkpoint_path='../external_repos/posegnn/pytorch_model.bin'):

        super().__init__()
        self.reg_drop_rate = reg_drop_rate
        self.num_targets = num_targets
        self.intermediate_size = intermediate_size

        self.hidden_size = 256
        checkpoint_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
        self.transformer = PosEGNN(checkpoint_dict["config"])
        # strict=False tolerates key mismatches. Report them so a wrong or partially
        # matching checkpoint cannot silently leave layers at their random init.
        load_result = self.transformer.load_state_dict(checkpoint_dict["state_dict"], strict=False)
        if load_result.missing_keys:
            print(f"WARNING: {len(load_result.missing_keys)} missing keys when loading "
                  f"{checkpoint_path}: {load_result.missing_keys[:10]}")
        if load_result.unexpected_keys:
            print(f"WARNING: {len(load_result.unexpected_keys)} unexpected keys when loading "
                  f"{checkpoint_path}: {load_result.unexpected_keys[:10]}")

    def forward(self, batch, layer_idx=-1):
        """Return mean-pooled graph embeddings (computed without gradients).

        Args:
            batch: PyG batch whose ``batch`` attribute maps each node to its graph;
                the remaining fields are whatever PosEGNN reads (not visible here).
            layer_idx: when ``"embedding_0"`` is 3-D, the index taken along its last
                axis (presumably a layer index — confirm against PosEGNN).

        Returns:
            Tensor with one pooled row per graph.
        """
        with torch.no_grad():
            output_dict = self.transformer(batch)
            node_emb = output_dict["embedding_0"]
            if node_emb.dim() == 3:
                node_emb = node_emb[:, :, layer_idx]  # -> [total_nodes, hidden_dim]
            # `batch.batch` is the graph index for each node
            graph_emb = global_mean_pool(node_emb, batch.batch)  # -> [batch_size, hidden_dim]
        return graph_emb

In [None]:
# Task directory names to leave out of the dataset-dumping loop (empty = process all).
skip: list[str] = list()

In [None]:
# NOTE(review): anomaly detection only helps when debugging backward passes and slows
# autograd noticeably; this cell does no backprop. Kept so later cells see the same
# global setting as before. Consider removing.
torch.autograd.set_detect_anomaly(True)

base_dir = '../input_data/tdcommons/admet_group'
dfs = []


def _standardize_split(frame: pd.DataFrame) -> pd.DataFrame:
    """Rename TDC's Drug/Y columns to smiles/target and drop the Drug_ID column."""
    return frame.rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)


# sorted() makes the processing order (and the log) reproducible across filesystems.
for task in sorted(os.listdir(base_dir)):
    if task.startswith('.'):  # skip hidden entries such as .ipynb_checkpoints
        continue
    if task in skip:
        continue

    print(task)

    prefix = 'tdcommons/'
    if prefix+task in utils.tdc_mae_tasks:
        metric = 'mae'
    elif prefix+task in utils.tdc_spearman_task:
        metric = 'spearman'
    elif prefix+task in utils.polaris_pearson_tasks:
        metric = 'pearson'
    elif prefix+task in utils.tdc_auroc_tasks:
        metric = 'auc'
    elif prefix+task in utils.tdc_aucpr_tasks:
        metric = 'aucpr'
    elif prefix+task in utils.tdc_aucpr2_tasks:
        metric = 'aucpr'
    elif prefix+task in utils.polaris_aucpr_tasks:
        metric = 'aucpr'
    else:
        raise ValueError(f"Task {task} not found in any known task list.")

    # Try the ADME loader first and fall back to Tox. Catch Exception rather than use a
    # bare except, so KeyboardInterrupt/SystemExit still stop the loop.
    try:
        data = ADME(name=task)
    except Exception:
        data = Tox(name=task)

    split = data.get_split(method='scaffold')

    train_df = _standardize_split(split['train'])
    val_df = _standardize_split(split['valid'])
    test_df = _standardize_split(split['test'])

    if metric in ('mae', 'spearman', 'pearson'):
        # Standardize regression targets using train-split statistics only.
        scaler = StandardScaler()
        train_vals = train_df[['target']].values
        scaler.fit(train_vals)
        train_df['target'] = scaler.transform(train_vals)
        val_df['target'] = scaler.transform(val_df[['target']].values)
        test_df['target'] = scaler.transform(test_df[['target']].values)

    for df, split_name in zip([train_df, val_df, test_df], ['train', 'val', 'test']):
        split_dir = f'3D_data/{task}/{split_name}'
        out_dir = f'{split_dir}/graphs'
        # The marker is written only after a split finishes, so an interrupted run is
        # resumed rather than mistaken for complete. It lives outside graphs/ so that
        # directory keeps containing only .pt files. Splits dumped before this marker
        # existed are re-walked once; already-written graphs are not recomputed.
        done_marker = os.path.join(split_dir, '.complete')
        if os.path.exists(done_marker):
            continue

        os.makedirs(out_dir, exist_ok=True)

        for idx, row in df.iterrows():
            out_path = os.path.join(out_dir, f'{idx}.pt')
            if os.path.exists(out_path):  # written by an earlier, interrupted run
                continue

            atoms = smiles_to_ase_atoms(row.smiles)
            if atoms is None:
                continue
            graph = build_data_from_ase(atoms)
            assert graph.num_nodes > 0, f"no atoms in idx={idx}"
            graph.y = torch.tensor([row.target], dtype=torch.float32)

            if not torch.isfinite(graph.pos).all():
                print(f"WARNING: NaN or Inf in positions for SMILES: {row.smiles}, idx: {idx}")
                continue
            if not torch.isfinite(graph.y).all():
                # This might indicate issues with the scaler or original data
                print(f"WARNING: NaN or Inf in target for SMILES: {row.smiles}, idx: {idx}")
                continue

            # Write-then-rename, so an interruption never leaves a truncated .pt in
            # graphs/ that the resume check above would treat as finished.
            tmp_path = os.path.join(split_dir, f'.{idx}.pt.tmp')
            torch.save(graph, tmp_path)
            os.replace(tmp_path, out_path)

        n_graphs = sum(name.endswith('.pt') for name in os.listdir(out_dir))
        with open(done_marker, 'w'):
            pass
        print(f"✔ Finished dumping {n_graphs} graphs for {task}/{split_name}")