In [None]:
# ! pip install --upgrade torch_geometric
# ! pip install ase==3.24.0
# ! pip install torch_nl==0.3
# ! pip install rdkit
# ! pip install PyTDC

In [None]:
import gc
import os
import glob, copy

import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

from sklearn.metrics import mean_absolute_error, roc_auc_score, average_precision_score
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr, pearsonr

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import AdamW
from torch.nn import MSELoss, BCEWithLogitsLoss

from transformers import get_linear_schedule_with_warmup

from ase import Atoms as ASEAtoms

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

from tdc.single_pred import ADME, Tox



import sys
sys.path.append('../external_repos/')
from posegnn.model import PosEGNN

sys.path.append('../')
import utils

# Compute device for models and batches; fall back to CPU when no GPU is visible
# (the original hard-coded 'cuda' crashed outright on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def smiles_to_ase_atoms(smiles: str, random_seed: int = 42,
                        jitter_amp: float = 1e-2) -> ASEAtoms | None:
    """Convert a SMILES string into an ASE ``Atoms`` object with coordinates.

    Conformer generation is attempted with three strategies, in order:

    1. ``EmbedMolecule`` with random coordinates, UFF minimisation, a small
       uniform jitter, then MMFF optimisation.
    2. ``EmbedMolecule`` with ETKDGv3 parameters, seeded with ``random_seed``.
    3. ``Compute2DCoords`` as a last resort (planar coordinates).

    Hydrogens are added for embedding and removed with ``Chem.RemoveHs`` before
    the ASE object is built.

    Args:
        smiles: Input SMILES string.
        random_seed: Seed for the RDKit embeddings and for the coordinate jitter.
        jitter_amp: Width of the uniform jitter applied after UFF minimisation;
            each coordinate is shifted by a value in ``[-jitter_amp/2, jitter_amp/2)``.

    Returns:
        ``ase.Atoms`` with ``pbc=False``, or ``None`` when the SMILES cannot be
        parsed, has no non-hydrogen atom, no conformer could be generated, or two
        atoms end up at (numerically) the same position.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Nothing useful can be built without at least one non-hydrogen atom; checking
    # up front also avoids running three embedding attempts for nothing.
    if not any(atom.GetSymbol() != 'H' for atom in mol.GetAtoms()):
        return None

    mol = Chem.AddHs(mol)
    conformer_generated = False

    # Method 1: random-coordinate embedding + UFF + jitter + MMFF.
    try:
        embed_result = AllChem.EmbedMolecule(mol, maxAttempts=2000, randomSeed=random_seed, useRandomCoords=True)
        if embed_result == 0:  # 0 == success
            uff_props = AllChem.UFFGetMoleculeForceField(mol)
            if uff_props is not None:
                uff_props.Initialize()
                uff_props.Minimize(maxIts=1000)

                # Jitter to avoid exact symmetries. A local RandomState draws the same
                # numbers as np.random.seed(random_seed) + np.random.rand(...) did,
                # without resetting NumPy's global RNG for the rest of the notebook.
                conf = mol.GetConformer()
                pos = conf.GetPositions()
                rng = np.random.RandomState(random_seed)
                new_pos = pos + (rng.rand(*pos.shape) - 0.5) * jitter_amp
                for i, p in enumerate(new_pos):
                    conf.SetAtomPosition(i, p.tolist())

                AllChem.MMFFOptimizeMolecule(mol)
                conformer_generated = True
    except Exception as e:
        print(f"Method 1 failed: {e}")

    # Method 2: ETKDGv3, seeded so the fallback is reproducible as well.
    if not conformer_generated:
        try:
            params = AllChem.ETKDGv3()
            params.randomSeed = random_seed
            if AllChem.EmbedMolecule(mol, params) == 0:
                conformer_generated = True
        except Exception as e:
            print(f"Method 2 failed: {e}")

    # Method 3: 2D coordinates. Compute2DCoords clears existing conformers and adds
    # a fresh one, so no placeholder conformer needs to be created beforehand.
    # NOTE: coordinates from this path are planar and are returned as-is.
    if not conformer_generated:
        try:
            AllChem.Compute2DCoords(mol)
            conformer_generated = mol.GetNumConformers() > 0
        except Exception as e:
            print(f"Method 3 failed: {e}")

    if not conformer_generated or mol.GetNumConformers() == 0:
        return None

    # Reject conformers where two atoms (hydrogens included) coincide.
    pos = mol.GetConformer().GetPositions()
    pos_t = torch.tensor(pos)
    dists = torch.cdist(pos_t, pos_t)
    dists.fill_diagonal_(float('inf'))
    if (dists <= 1e-6).any().item():
        return None

    mol = Chem.RemoveHs(mol)
    return ASEAtoms(
        symbols=[a.GetSymbol() for a in mol.GetAtoms()],
        positions=mol.GetConformer().GetPositions(),
        pbc=False
    )

In [None]:
def build_data_from_ase(atoms: ASEAtoms) -> Data:
    """Wrap an ASE ``Atoms`` object as a single-graph PyG ``Data`` object.

    Attributes set on the result:
        z:          atomic numbers as int64.
        pos:        ``atoms.get_positions()`` as float32.
        box:        ``atoms.get_cell()`` as float32 with a leading axis of size 1.
        batch:      int64 zeros, one entry per atom (every node is in graph 0).
        num_graphs: 1.
    """
    atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long)
    cell = torch.tensor(atoms.get_cell().tolist()).float()
    coords = torch.tensor(atoms.get_positions().tolist()).float()
    graph_index = torch.zeros_like(atomic_numbers)
    return Data(
        z=atomic_numbers,
        pos=coords,
        box=cell[None],
        batch=graph_index,
        num_graphs=1,
    )

In [None]:
class FileBacked3DDataset(Dataset):
    """Map-style dataset that lazily loads one pre-dumped graph per ``.pt`` file.

    ``root_folder`` is scanned once, at construction, for ``*.pt`` files. The paths
    are sorted as plain strings (so ``'10.pt'`` comes before ``'2.pt'``).
    """

    def __init__(self, root_folder: str):
        # glob.escape so folder names containing '[', '*' or '?' are matched literally.
        self.paths = sorted(glob.glob(f'{glob.escape(root_folder)}/*.pt'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int) -> "Data":
        """Load graph ``idx``, tag it with ``data.idx = idx`` and validate its tensors.

        Every tensor attribute is returned on CPU and detached.

        Raises:
            ValueError: if any tensor attribute contains a NaN.
        """
        # The files hold whole pickled graph objects written by this notebook, not
        # plain state dicts. weights_only=True (torch's default since 2.6) would
        # refuse to unpickle them. weights_only=False runs arbitrary unpickling,
        # so only point this dataset at trusted, self-generated files.
        data = torch.load(self.paths[idx], weights_only=False)
        data.idx = idx  # position in self.paths; lets bad samples be traced to a file
        for key, value in data:
            if not isinstance(value, torch.Tensor):
                continue
            if torch.isnan(value).any():
                raise ValueError(f"🛑 NaN detected in '{key}'")
            data[key] = value.cpu().detach()
        return data

def collate_fn(batch_list):
    """Merge a list of PyG ``Data`` graphs into one ``Batch``.

    Not passed to any of the ``DataLoader``s in this notebook; kept as an explicit
    collate hook.
    """
    merged = Batch.from_data_list(batch_list)
    return merged

In [None]:
class Transformer(nn.Module):
    """Pretrained PosEGNN backbone with a small MLP prediction head.

    The backbone is built from the ``"config"`` entry of a checkpoint file, and
    its weights come from the ``"state_dict"`` entry. Node embeddings are
    mean-pooled per graph and fed through
    Dropout -> Linear -> SiLU -> Dropout -> Linear.
    """

    def __init__(self,
                 reg_drop_rate=0.1,
                 intermediate_size=256,
                 num_targets=1,
                 checkpoint_path='pytorch_model.bin'):
        """
        Args:
            reg_drop_rate: dropout probability used (twice) in the head.
            intermediate_size: width of the head's hidden layer.
            num_targets: number of outputs per graph.
            checkpoint_path: file containing ``"config"`` and ``"state_dict"``.
        """
        super().__init__()
        self.reg_drop_rate = reg_drop_rate
        self.num_targets = num_targets
        self.intermediate_size = intermediate_size

        # Width of the pooled node embedding fed to the head.
        # NOTE(review): hard-coded; assumed to match the checkpoint config — TODO confirm.
        self.hidden_size = 256
        checkpoint_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
        self.transformer = PosEGNN(checkpoint_dict["config"])

        # strict=False tolerates key mismatches. Doing so silently would hide a
        # backbone that is partly (or wholly) left at its random initialisation,
        # so report any mismatch.
        load_result = self.transformer.load_state_dict(checkpoint_dict["state_dict"], strict=False)
        missing = list(getattr(load_result, "missing_keys", []))
        unexpected = list(getattr(load_result, "unexpected_keys", []))
        if missing or unexpected:
            import warnings  # local import: only needed on this path
            warnings.warn(
                f"Checkpoint '{checkpoint_path}' does not match PosEGNN exactly: "
                f"{len(missing)} missing key(s) (e.g. {missing[:3]}), "
                f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:3]}). "
                "Missing parameters keep their random initialisation.",
                stacklevel=2,
            )

        self.regressor = nn.Sequential(
            nn.Dropout(self.reg_drop_rate),
            nn.Linear(self.hidden_size, self.intermediate_size),
            nn.SiLU(),
            nn.Dropout(self.reg_drop_rate),
            nn.Linear(self.intermediate_size, self.num_targets)
        )

    def forward(self, batch, layer_idx=-1):
        """Predict ``num_targets`` values for every graph in ``batch``.

        Args:
            batch: PyG batch passed straight to the backbone; ``batch.batch`` maps
                each node to its graph.
            layer_idx: index into the last axis of ``embedding_0`` when the
                backbone returns a 3-D node embedding (presumably one slice per
                layer — TODO confirm against PosEGNN); ignored otherwise.

        Returns:
            Tensor of shape ``[num_graphs, num_targets]``.
        """
        output_dict = self.transformer(batch)
        node_emb = output_dict["embedding_0"]
        if node_emb.dim() == 3:
            node_emb = node_emb[:, :, layer_idx]  # -> [total_nodes, hidden_dim]
        graph_emb = global_mean_pool(node_emb, batch.batch)  # -> [num_graphs, hidden_dim]
        return self.regressor(graph_emb)

In [None]:
# TDC ADMET tasks evaluated below; each gets a full layer x learning-rate sweep.
tasks = [
    'caco2_wang',
    'clearance_hepatocyte_az',
    'cyp2c9_substrate_carbonmangels',
    'cyp3a4_substrate_carbonmangels',
    'cyp3a4_veith',
    'dili',
    'half_life_obach',
    'herg',
    'ld50_zhu',
    'lipophilicity_astrazeneca',
    'ppbr_az',
    'solubility_aqsoldb',
]

In [None]:
# Fine-tune PosEGNN on every task, for every backbone layer, sweeping learning
# rates. Select the best (learning rate, epoch) on validation, then report test.

# Anomaly detection slows every backward pass dramatically; enable it only while
# hunting a NaN/Inf in the gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

SEED = 42
LEARNING_RATES = [1e-5, 2e-5, 5e-5, 1e-4, 2e-4]
EPOCHS = 50
NUM_LAYERS = 4
TRAIN_BATCH_SIZE = 64
EVAL_BATCH_SIZE = 256
ACCUMULATION_STEPS = 1
RESULTS_DIR = 'tmp'
REGRESSION_METRICS = ('mae', 'spearman', 'pearson')

base_dir = 'tdcommons/admet_group'
dfs = []


def _metric_for_task(task):
    """Return the metric name configured for ``task`` in the ``utils`` task lists."""
    key = 'tdcommons/' + task
    # Looked up lazily, in priority order, like the original if/elif chain.
    for attr, metric_name in (('tdc_mae_tasks', 'mae'),
                              ('tdc_spearman_task', 'spearman'),
                              ('polaris_pearson_tasks', 'pearson'),
                              ('tdc_auroc_tasks', 'auc'),
                              ('tdc_aucpr_tasks', 'aucpr'),
                              ('tdc_aucpr2_tasks', 'aucpr'),
                              ('polaris_aucpr_tasks', 'aucpr')):
        if key in getattr(utils, attr):
            return metric_name
    raise ValueError(f"Task {task} not found in any known task list.")


def _score(metric, targets, preds):
    """Compute ``metric`` on 1-D targets/predictions (lower is better only for 'mae')."""
    targets = np.asarray(targets)
    preds = np.asarray(preds)
    if metric == 'mae':
        return np.mean(np.abs(preds - targets))
    if metric == 'spearman':
        return spearmanr(targets, preds)[0]
    if metric == 'pearson':
        return pearsonr(targets, preds)[0]
    if metric == 'auc':
        return roc_auc_score(targets, preds)
    if metric == 'aucpr':
        return average_precision_score(targets, preds)
    raise ValueError(f"Unknown metric {metric!r}")


def _predict(model, loader, layer_idx):
    """Run ``model`` over ``loader`` in eval mode; return (preds, targets) as flat lists."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.extend(model(batch, layer_idx).view(-1).cpu().tolist())
            targets.extend(batch.y.view(-1).cpu().tolist())
    return preds, targets


def _dump_split_graphs(task, split_name, split_df):
    """Save one graph per convertible SMILES to ``3D_data/<task>/<split>/graphs/<row index>.pt``.

    The whole split is skipped if its folder already holds at least one ``.pt``
    file (a partially written folder is therefore treated as complete).
    """
    out_dir = f'3D_data/{task}/{split_name}/graphs'
    if os.path.isdir(out_dir) and any(f.endswith('.pt') for f in os.listdir(out_dir)):
        return
    os.makedirs(out_dir, exist_ok=True)

    for idx, row in split_df.iterrows():
        atoms = smiles_to_ase_atoms(row.smiles)
        if atoms is None:
            continue
        graph = build_data_from_ase(atoms)
        assert graph.num_nodes > 0, f"no atoms in idx={idx}"
        graph.y = torch.tensor([row.target], dtype=torch.float32)

        if not torch.isfinite(graph.pos).all():
            print(f"WARNING: NaN or Inf in positions for SMILES: {row.smiles}, idx: {idx}")
            continue
        if not torch.isfinite(graph.y).all():
            print(f"WARNING: NaN or Inf in target for SMILES: {row.smiles}, idx: {idx}")
            continue

        torch.save(graph, os.path.join(out_dir, f'{idx}.pt'))
    print(f"✔ Finished dumping {len(os.listdir(out_dir))} graphs for {task}/{split_name}")


def _is_better(score, best, minimize):
    """True if ``score`` beats ``best`` (a NaN score never does)."""
    return score < best if minimize else score > best


def _train_one_lr(lr, layer_idx, metric, train_loader, val_loader):
    """Fine-tune a fresh model at ``lr``; return (best_val, best_state) over epochs.

    ``best_state`` is a CPU copy of the best epoch's state dict, or ``None`` if the
    validation score never improved (e.g. it was NaN every epoch).
    """
    torch.manual_seed(SEED)  # reproducible head init, dropout and shuffling
    model = Transformer().to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    # One scheduler step per optimizer step, i.e. per ACCUMULATION_STEPS batches (rounded up).
    steps_per_epoch = -(-len(train_loader) // ACCUMULATION_STEPS)
    num_training_steps = steps_per_epoch * EPOCHS
    num_warmup_steps = int(0.05 * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    loss_fn = MSELoss() if metric in REGRESSION_METRICS else BCEWithLogitsLoss()

    minimize = metric == 'mae'
    best_val = float('inf') if minimize else -float('inf')
    best_state = None
    for epoch in range(EPOCHS):
        model.train()
        optimizer.zero_grad()
        pending = False  # gradients accumulated since the last optimizer step
        for step, batch in enumerate(train_loader):
            if torch.isnan(batch.pos).any() or torch.isnan(batch.y).any():
                print(f"NaN detected in input at step {step}")
                print(batch.batch)
                raise RuntimeError("Stopping due to NaN in batch")
            batch = batch.to(device)
            # Flatten both sides to [B]. squeeze() would turn a size-1 final batch
            # into a 0-d target, which BCEWithLogitsLoss rejects.
            preds = model(batch, layer_idx).view(-1)
            targets = batch.y.view(-1)
            loss = loss_fn(preds, targets) / ACCUMULATION_STEPS
            loss.backward()
            pending = True
            # Zero gradients only after a step, so accumulation actually accumulates.
            if (step + 1) % ACCUMULATION_STEPS == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                pending = False
        if pending:  # leftover gradients when len(train_loader) % ACCUMULATION_STEPS != 0
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        val_preds, val_targs = _predict(model, val_loader, layer_idx)
        val_score = _score(metric, val_targs, val_preds)
        if _is_better(val_score, best_val, minimize):
            best_val = val_score
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    return best_val, best_state


for task in tasks:
    if task.startswith('.'):
        continue

    print(task)
    metric = _metric_for_task(task)

    try:
        tdc_data = ADME(name=task)
    except Exception:
        # Not an ADME dataset name; fall back to the Tox group.
        tdc_data = Tox(name=task)

    tdc_split = tdc_data.get_split(method='scaffold')
    train_df, val_df, test_df = (
        tdc_split[name].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)
        for name in ('train', 'valid', 'test')
    )

    # Standardise regression targets with statistics from the training split only.
    scaler = None
    if metric in REGRESSION_METRICS:
        scaler = StandardScaler().fit(train_df[['target']].values)
        for frame in (train_df, val_df, test_df):
            frame['target'] = scaler.transform(frame[['target']].values).ravel()

    for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
        _dump_split_graphs(task, split_name, split_df)

    train_dataset = FileBacked3DDataset(f'3D_data/{task}/train/graphs')
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )
    val_dataset = FileBacked3DDataset(f'3D_data/{task}/val/graphs')
    val_dataloader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)
    test_dataset = FileBacked3DDataset(f'3D_data/{task}/test/graphs')
    test_dataloader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

    layer_indices = list(range(NUM_LAYERS))
    test_metrics = []
    for layer_idx in layer_indices:
        print(f'Layer: {layer_idx}')

        # Track the best checkpoint across the WHOLE learning-rate sweep. Resetting
        # this per learning rate would silently keep only the last rate's best.
        minimize = metric == 'mae'
        best_val = float('inf') if minimize else -float('inf')
        best_state, best_lr = None, None
        for lr in LEARNING_RATES:
            print(f'LR={lr}')
            lr_val, lr_state = _train_one_lr(lr, layer_idx, metric, train_dataloader, val_dataloader)
            if lr_state is not None and _is_better(lr_val, best_val, minimize):
                best_val, best_state, best_lr = lr_val, lr_state, lr
            gc.collect()
            torch.cuda.empty_cache()

        if best_state is None:
            print(f"WARNING: no finite validation {metric} for layer {layer_idx}; recording NaN")
            test_metrics.append(float('nan'))
            continue
        print(f"  best LR={best_lr}, val_{metric}={best_val:.4f}")

        # Evaluate the selected checkpoint on the test split.
        model = Transformer().to(device)
        model.load_state_dict(best_state)
        test_preds, test_targs = _predict(model, test_dataloader, layer_idx)

        if scaler is not None:
            # Report regression metrics in the original target units.
            test_preds = scaler.inverse_transform(np.array(test_preds).reshape(-1, 1)).flatten()
            test_targs = scaler.inverse_transform(np.array(test_targs).reshape(-1, 1)).flatten()

        test_score = _score(metric, test_targs, test_preds)
        test_metrics.append(test_score)
        print(f"Layer {layer_idx}: {test_score}")

    # One column per task, one row per layer index.
    results_df = pd.DataFrame({task: test_metrics}, index=layer_indices)
    dfs.append(results_df)
    os.makedirs(RESULTS_DIR, exist_ok=True)
    results_df.to_csv(os.path.join(RESULTS_DIR, f"posegnn_{task}_layer_results.csv"), index=False)

# dfs = pd.concat(dfs, axis=1)
# dfs.to_csv('./results_posegnn_finetune.csv', index=False)

In [None]:
# preds = model(batch.to(device), layer_idx)        # -> [B,1]
# preds = preds.squeeze(-1)                         # -> [B]

# nan_mask = torch.isnan(preds)
# if nan_mask.any():
#     bad_positions = nan_mask.nonzero(as_tuple=True)[0]
#     print("⚠️ NaN in preds at batch indices:", bad_positions.tolist())
#     # batch.idx gives you the original dataset index for each sample
#     bad_dataset_idxs = batch.idx[bad_positions].tolist()
#     print("… which correspond to dataset indices:", bad_dataset_idxs)
#     for di in bad_dataset_idxs:
#         print("  file:", train_dataset.paths[di])
#     raise RuntimeError("Stopping: found NaN in model output")