In [None]:
# ! pip install transformers==4.46.3
# ! pip install --upgrade torch_geometric
# ! pip install torch_nl==0.3
# ! pip install PyTDC
# ! pip install rdkit==2023.9.4
# ! pip install addict
# ! pip install ase==3.24.0

In [None]:
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import os
import sys
import gc
import random
import pickle

from tqdm import tqdm
import pandas as pd
import numpy as np
from rdkit import Chem
import ase
from ase.io import read
from ase import Atoms

from sklearn.model_selection import GroupKFold
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.stats import spearmanr, pearsonr

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import MSELoss, BCEWithLogitsLoss
from torch import distributed as dist
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool

from transformers import AutoTokenizer, AutoModel
from transformers import get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# allow imports from parent directory
# sys.path.append(os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)))
import sys
sys.path.append('../external_repos/')
from UniMol.unimol_tools.unimol_tools.models import UniMolModel
from UniMol.unimol_tools.unimol_tools.data.conformer import coords2unimol
from UniMol.unimol_tools.unimol_tools.data.dictionary import Dictionary
from tdc.benchmark_group import admet_group
from tdc.single_pred import ADME, Tox

sys.path.append('../')
import utils

In [None]:
# Token dictionary shared by the whole notebook: it is passed to `coords2unimol`
# when building per-molecule inputs and supplies the padding id used when batching.
# The path is relative to the notebook's working directory.
dictionary = Dictionary().load('unimol.dict.txt')

In [None]:
def get_smiles(input_file):
    """Return the SMILES string stored on the second line of an XYZ file.

    Only the first two lines are read, so large geometry files are not loaded
    into memory just to fetch their comment line.

    Args:
        input_file: Path to the ``.xyz`` file (anything accepted by ``open``).

    Returns:
        The content of the file's second line with surrounding whitespace removed.

    Raises:
        ValueError: If the file contains fewer than two lines.
    """
    with open(input_file, 'r') as f:
        f.readline()                 # line 1 is skipped
        comment_line = f.readline()  # line 2 carries the SMILES

    # readline() returns '' only at end-of-file, i.e. the second line is missing.
    if not comment_line:
        raise ValueError(f"{input_file} has no second line to read a SMILES string from")

    return comment_line.strip()

class XYZFolderDataset(Dataset):
    """Map-style dataset that pairs DataFrame rows with ``.xyz`` files via SMILES.

    Every ``*.xyz`` file in ``xyz_dir`` is indexed by the SMILES string returned
    by ``get_smiles``. Rows of ``df`` whose ``'smiles'`` value has no matching
    file are dropped (a warning with the count is printed).

    Args:
        df: pandas DataFrame with a ``'smiles'`` column and the column named by
            ``target_col``.
        xyz_dir: Directory containing the ``.xyz`` files.
        target_col: Name of the column holding the scalar target.
        dtype: dtype of the returned target tensor; defaults to
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        df,                        # your pandas DataFrame
        xyz_dir: Union[str, Path],
        target_col: str,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.xyz_dir = Path(xyz_dir)
        self.target_col = target_col
        self.dtype = dtype or torch.get_default_dtype()

        # Build the SMILES -> path lookup. Files are visited in sorted order so
        # that, if two files carry the same SMILES, the winner is deterministic.
        lookup = {}
        for p in sorted(self.xyz_dir.glob("*.xyz")):
            lookup[get_smiles(str(p))] = p
        self._lookup = lookup

        # Keep only the rows for which a geometry file exists.
        df = df.reset_index(drop=True)
        mask = df['smiles'].isin(set(lookup))
        num_bad = (~mask).sum()
        if num_bad > 0:
            print(f"Warning: dropping {num_bad} rows with no matching .xyz file")
        self.df = df.loc[mask].reset_index(drop=True)

    def __len__(self) -> int:
        """Number of rows that have a matching ``.xyz`` file."""
        return len(self.df)

    def _read_atoms(self, smiles):
        """Read the ASE ``Atoms`` object for ``smiles``; raise ``KeyError`` if unknown."""
        xyz_path = self._lookup.get(smiles)
        if xyz_path is None:
            raise KeyError(f"Could not find .xyz for SMILES `{smiles}` in {self.xyz_dir}")
        return ase.io.read(str(xyz_path))

    def __getitem__(self, idx: int):
        """Return ``(graph, target)`` for row ``idx``.

        ``graph`` is whatever ``coords2unimol`` produces from the element symbols
        and Cartesian positions; ``target`` is a scalar tensor of ``self.dtype``.
        """
        row = self.df.iloc[idx]
        target = row[self.target_col]

        xyz = self._read_atoms(row['smiles'])
        # element symbols and positions are the inputs of the UniMol featurizer
        atoms = xyz.get_chemical_symbols()
        coords = xyz.get_positions()

        graph = coords2unimol(atoms, coords, dictionary)

        return graph, torch.tensor(target, dtype=self.dtype)

    def get_atom(self, idx: int) -> ase.Atoms:
        """Return the raw ASE Atoms for this index."""
        row = self.df.iloc[idx]
        return self._read_atoms(row['smiles'])

    def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]:
        """Return both ASE Atoms and the row's metadata dict."""
        row = self.df.iloc[idx]
        atoms = self._read_atoms(row['smiles'])
        return atoms, row.to_dict()

In [None]:
import torch
import torch.nn.functional as F

def collate_graph_and_targets(batch):
    """
    Pad a list of per-molecule samples to a common size and stack them.

    batch: list of (graph_dict, target_tensor)
      - graph_dict holds array-likes under the keys
          'src_tokens'    -> [N]
          'src_distance'  -> [N, N]
          'src_coord'     -> [N, 3]
          'src_edge_type' -> [N, N]
      - target_tensor: torch.Tensor scalar

    Returns:
      batched_graph: dict of torch.Tensor with shapes
           'src_tokens':    [B, N_max]
           'src_distance':  [B, N_max, N_max]
           'src_coord':     [B, N_max, 3]
           'src_edge_type': [B, N_max, N_max]
      batched_targets: torch.Tensor [B]

    Tokens and edge types are padded with ``dictionary.pad()``; distances and
    coordinates are padded with zeros.
    """
    graphs, targets = zip(*batch)
    # the largest molecule in the batch defines the padded size
    max_nodes = max(g['src_tokens'].shape[0] for g in graphs)
    pad_id = dictionary.pad()  # relies on the module-level dictionary

    stacks = {
        'src_tokens': [],
        'src_distance': [],
        'src_coord': [],
        'src_edge_type': [],
    }

    for graph in graphs:
        tok = torch.as_tensor(graph['src_tokens'], dtype=torch.long)
        dst = torch.as_tensor(graph['src_distance'], dtype=torch.float)
        xyz = torch.as_tensor(graph['src_coord'], dtype=torch.float)
        etype = torch.as_tensor(graph['src_edge_type'], dtype=torch.long)

        extra = max_nodes - tok.size(0)
        # F.pad takes (left, right) pairs starting from the last dimension
        both_dims = (0, extra, 0, extra)
        rows_only = (0, 0, 0, extra)

        stacks['src_tokens'].append(F.pad(tok, (0, extra), value=pad_id))
        stacks['src_distance'].append(F.pad(dst, both_dims, value=0.0))
        stacks['src_coord'].append(F.pad(xyz, rows_only, value=0.0))
        stacks['src_edge_type'].append(F.pad(etype, both_dims, value=pad_id))

    batched_graph = {key: torch.stack(parts, dim=0) for key, parts in stacks.items()}
    return batched_graph, torch.stack(targets, dim=0)

In [None]:
def make_loader(
    df,
    xyz_dir: str,
    target_col: str,
    batch_size: int,
    num_workers: int,
    shuffle:  bool,
    dtype=None,
):
    """Wrap ``df`` and its ``.xyz`` folder in a ``DataLoader``.

    An ``XYZFolderDataset`` is built from ``df``, ``xyz_dir``, ``target_col`` and
    ``dtype``; batches are assembled by ``collate_graph_and_targets`` and memory
    pinning is always enabled.

    Returns:
        The configured ``DataLoader``.
    """
    xyz_dataset = XYZFolderDataset(df=df, xyz_dir=xyz_dir, target_col=target_col, dtype=dtype)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_graph_and_targets,
        'pin_memory': True,
    }
    return DataLoader(xyz_dataset, **loader_options)

In [None]:
class Model(nn.Module):
    """UniMol encoder with a two-layer MLP head on top of one encoder layer.

    Args:
        reg_drop_rate: Dropout probability used before each linear layer of the head.
        reg_size: Width of the head's hidden layer.
        num_labels: Number of outputs produced per molecule.
    """

    def __init__(self,
                reg_drop_rate=0.1,
                reg_size=256,
                num_labels=1):
        super().__init__()
        self.reg_drop_rate = reg_drop_rate
        self.num_targets = num_labels
        self.reg_size = reg_size

        self.gnn = UniMolModel(remove_hs=True)

        # presumably the encoder's embedding width -- TODO confirm against UniMolModel
        self.hidden_size = 512
        head_layers = [
            nn.Dropout(self.reg_drop_rate),
            nn.Linear(self.hidden_size, self.reg_size),
            nn.SiLU(),
            nn.Dropout(self.reg_drop_rate),
            nn.Linear(self.reg_size, self.num_targets),
        ]
        self.regressor = nn.Sequential(*head_layers)

    def forward(self, batched_graph, layer_idx):
        """Encode a batch and regress from the representation at position 0.

        Args:
            batched_graph: Dict of keyword inputs forwarded to the encoder.
            layer_idx: Forwarded to the encoder as ``layer_idx``.

        Returns:
            ``(graph_emb, output)``: the embedding taken at index 0 of the second
            dimension (presumably the [CLS] token -- verify), and the head's output.
        """
        encoder_out = self.gnn(**batched_graph, layer_idx=layer_idx, return_repr=True)
        # last entry of the intermediate representations returned by the encoder
        layer_repr = encoder_out["intermediate_layer_reprs"][-1]
        graph_emb = layer_repr[:, 0, :]
        return graph_emb, self.regressor(graph_emb)

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype  = torch.float32
# NOTE(review): this name shadows the builtin `compile` and is not read anywhere
# in the visible cells -- confirm whether it is still needed.
compile = None

# Global numeric settings: request 'high' float32 matmul precision and allow
# TF32 in CUDA matmul and cuDNN kernels.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32  = True

In [None]:
# Instantiate a model with the default head settings. The training cell reads
# `model.gnn.encoder.layers` to know how many encoder layers to sweep over.
# The trailing semicolon suppresses the cell's repr output.
model = Model();

In [None]:
def move_to_device(batch, device):
    """Recursively move everything inside ``batch`` that has a ``.to`` method to ``device``.

    Args:
        batch: An object exposing ``.to(device)`` (e.g. a tensor), or a dict,
            list, tuple or namedtuple nesting such objects. Anything else is
            returned untouched.
        device: Target device, passed to each ``.to`` call.

    Returns:
        A structure mirroring ``batch``. Dicts come back as plain ``dict``;
        lists, tuples and namedtuples keep their type.
    """
    # Covers tensors and any other object exposing `.to`.
    if hasattr(batch, 'to'):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    # namedtuples take their fields as positional arguments, not one iterable
    if isinstance(batch, tuple) and hasattr(batch, '_fields'):
        return type(batch)(*(move_to_device(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    return batch

In [None]:
# Names of the TDC tasks to run.
tasks = ['solubility_aqsoldb']

In [None]:
import copy
from sklearn.preprocessing import StandardScaler

# NOTE(review): anomaly detection is a debugging aid that adds overhead to every
# backward pass; consider turning it off for full sweeps once training is stable.
torch.autograd.set_detect_anomaly(True)

base_dir = 'tdcommons/admet_group'
dfs = []

# Create the output directory up front so a long run is not lost to a missing
# folder when the results CSV is written.
os.makedirs('tmp', exist_ok=True)


def _collect_predictions(model, loader, layer_idx):
    """Run ``model`` over ``loader`` without gradients; return flat (preds, targets) lists."""
    model.eval()
    preds_all, targs_all = [], []
    with torch.no_grad():
        for mol_graphs, targets in loader:
            mol_graphs = move_to_device(mol_graphs, device)
            _, preds_tensor = model(mol_graphs, layer_idx)
            preds_all.extend(preds_tensor.view(-1).cpu().numpy().tolist())
            targs_all.extend(targets.view(-1).cpu().numpy().tolist())
    return preds_all, targs_all


def _compute_metric(metric, targs, preds):
    """Score ``preds`` against ``targs`` with the metric named by ``metric``."""
    if metric == 'mae':
        return np.mean(np.abs(np.array(preds) - np.array(targs)))
    if metric == 'spearman':
        return spearmanr(targs, preds)[0]
    if metric == 'pearson':
        return pearsonr(targs, preds)[0]
    if metric == 'auc':
        return roc_auc_score(targs, preds)
    if metric == 'aucpr':
        return average_precision_score(targs, preds)
    raise ValueError(f"Unknown metric {metric}")


for task in tasks:
    if task.startswith('.'):
        continue

    print(task)

    prefix = 'tdcommons/'
    if prefix+task in utils.tdc_mae_tasks:
        metric = 'mae'
    elif prefix+task in utils.tdc_spearman_task:
        metric = 'spearman'
    elif prefix+task in utils.polaris_pearson_tasks:
        metric = 'pearson'
    elif prefix+task in utils.tdc_auroc_tasks:
        metric = 'auc'
    elif prefix+task in utils.tdc_aucpr_tasks:
        metric = 'aucpr'
    elif prefix+task in utils.tdc_aucpr2_tasks:
        metric = 'aucpr'
    elif prefix+task in utils.polaris_aucpr_tasks:
        metric = 'aucpr'
    else:
        raise ValueError(f"Task {task} not found in any known task list.")

    is_regression = metric in ('mae', 'spearman', 'pearson')
    # MAE is minimised; every other metric is maximised.
    higher_is_better = metric != 'mae'

    # The task is looked up in the ADME collection first, then in Tox.
    # `except Exception` keeps KeyboardInterrupt/SystemExit from being swallowed.
    try:
        data = ADME(name = task)
    except Exception:
        data = Tox(name = task)

    split = data.get_split(method = 'scaffold')

    train_df = split['train'].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)
    val_df = split['valid'].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)
    test_df = split['test'].rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)

    if is_regression:
        scaler = StandardScaler()
        # fit only on train targets, then apply the same scaling to every split
        train_vals = train_df[['target']].values
        scaler.fit(train_vals)
        train_df['target'] = scaler.transform(train_vals)
        val_df['target']   = scaler.transform(val_df[['target']].values)
        test_df['target']  = scaler.transform(test_df[['target']].values)

    train_dataloader = make_loader(
        df=train_df,
        xyz_dir=f'xyz_files/{task}/train/graphs',
        target_col='target',
        batch_size=32,
        num_workers=4,
        dtype=torch.float32,
        shuffle=True
    )

    val_dataloader = make_loader(
        df=val_df,
        xyz_dir=f'xyz_files/{task}/val/graphs',
        target_col='target',
        batch_size=64,
        num_workers=1,
        dtype=torch.float32,
        shuffle=False
    )

    test_dataloader = make_loader(
        df=test_df,
        xyz_dir=f'xyz_files/{task}/test/graphs',
        target_col='target',
        batch_size=64,
        num_workers=1,
        dtype=torch.float32,
        shuffle=False
    )

    # `model` here is the instance left over from an earlier cell / iteration;
    # it is only used to count the encoder layers.
    num_layers = len(model.gnn.encoder.layers)
    layer_indices = list(range(0, num_layers))

    test_metrics = []
    for layer_idx in layer_indices:
        print(f'Layer: {layer_idx}')

        accumulation_steps = 2
        epochs = 50
        # Optimizer updates per epoch: one per full accumulation window plus one
        # for a trailing partial window. The scheduler is stepped once per
        # update, so its horizon must be counted in updates, not in batches.
        updates_per_epoch = (len(train_dataloader) + accumulation_steps - 1) // accumulation_steps
        num_training_steps = updates_per_epoch * epochs
        num_warmup_steps = int(0.05 * num_training_steps)

        # Best checkpoint across the WHOLE learning-rate sweep for this layer.
        best_val = -float('inf') if higher_is_better else float('inf')
        best_state = None
        best_epoch = -1
        best_lr = None

        # hyperparameter sweep over learning rates
        for lr in [1e-5, 2e-5, 5e-5, 1e-4, 2e-4]:
            print(f'LR={lr}')
            model = Model(reg_size=32).to(device)
            optimizer = AdamW(model.parameters(), lr=lr)
            scheduler_warmup = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
            # choose loss fn by task type
            loss_fn = MSELoss() if is_regression else BCEWithLogitsLoss()

            for epoch in tqdm(range(epochs)):
                model.train()
                # Gradients are cleared once per accumulation window, never
                # between the micro-batches of a window.
                optimizer.zero_grad()
                step = -1  # stays -1 if the loader yields no batches
                for step, (mol_graphs, targets) in enumerate(train_dataloader):
                    mol_graphs = move_to_device(mol_graphs, device)
                    targets = targets.to(device)
                    _, preds = model(mol_graphs, layer_idx)

                    preds_for_loss = preds.squeeze(-1)
                    # view(-1) keeps a size-1 batch 1-D (squeeze() would make it 0-d)
                    targets_for_loss = targets.view(-1)

                    loss = loss_fn(preds_for_loss, targets_for_loss)
                    loss = loss / accumulation_steps
                    loss.backward()

                    # every `accumulation_steps`, do an optimizer step + scheduler step
                    if (step + 1) % accumulation_steps == 0:
                        optimizer.step()
                        scheduler_warmup.step()
                        optimizer.zero_grad()

                # finish off any remaining gradients if the batch count % acc_steps != 0
                if (step + 1) % accumulation_steps != 0:
                    optimizer.step()
                    scheduler_warmup.step()
                    optimizer.zero_grad()

                # evaluate on validation set
                gc.collect()
                torch.cuda.empty_cache()
                val_preds, val_targs = _collect_predictions(model, val_dataloader, layer_idx)
                val_score = _compute_metric(metric, val_targs, val_preds)

                improved = val_score > best_val if higher_is_better else val_score < best_val
                if improved:
                    best_val   = val_score
                    best_state = copy.deepcopy(model.state_dict())
                    best_epoch = epoch
                    best_lr    = lr

        del model
        gc.collect()
        torch.cuda.empty_cache()

        # A NaN validation score never compares as an improvement, so no
        # checkpoint may have been stored; fail with a clear message.
        if best_state is None:
            raise RuntimeError(
                f"No checkpoint was selected for task {task}, layer {layer_idx}: "
                f"validation {metric} never improved (NaN scores?)"
            )
        print(f'Selected LR={best_lr} (epoch {best_epoch}, val_{metric}={best_val})')

        # now evaluate that best model on the test set
        model = Model(reg_size=32).to(device)
        model.load_state_dict(best_state)
        test_preds, test_targs = _collect_predictions(model, test_dataloader, layer_idx)

        if is_regression:
            # bring preds back to original units
            test_preds = scaler.inverse_transform(
                np.array(test_preds).reshape(-1, 1)
            ).flatten()
            # undo the scaling on the test targets as well
            test_targs = scaler.inverse_transform(
                np.array(test_targs).reshape(-1, 1)
            ).flatten()

        test_score = _compute_metric(metric, test_targs, test_preds)
        test_metrics.append(test_score)

        print(f"Layer {layer_idx}: {test_score}")

    # save a DataFrame with one column per task and rows = layer indices
    results_df = pd.DataFrame({task: test_metrics}, index=layer_indices)
    dfs.append(results_df)
    results_df.to_csv(f"tmp/unimol1_{task}_results.csv", index=False)