In [None]:
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import os, sys, gc
import pandas as pd
from rdkit.Chem import AllChem
import numpy as np
from rdkit import Chem
import ase
import ase.io

from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.stats import spearmanr, pearsonr

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from torch.optim import AdamW
from torch.nn import MSELoss, BCEWithLogitsLoss
from torch_geometric.nn import global_mean_pool

import orb_models.utils as utils
from orb_models.forcefield.base import AtomGraphs
from orb_models.dataset.augmentations import rotate_randomly
from orb_models.dataset.base_datasets import AtomsDataset

sys.path.append('../external_repos/')
from orb_models_modified.orb_models.forcefield import base, pretrained, atomic_system, property_definitions

sys.path.append('../')
import utils as tdc_utils

from tdc.single_pred import ADME, Tox

In [None]:
# from tdc.benchmark_group import admet_group

# group = admet_group(path = 'tdcommons/')
# benchmark = group.get('ames')

# name = benchmark['name']
# train_val, test = benchmark['train_val'], benchmark['test']

In [None]:
import hashlib

def smiles_to_xyz_file(smiles: str, output_dir: str = 'xyz_files/', random_seed: int = 42,
                       jitter_amp: float = 1e-2) -> str | None:
    """Embed a SMILES string in 3D and write its heavy atoms to an XYZ file.

    Conformer generation is attempted in order of preference:

    1. Random-coordinate embedding, UFF minimisation, a small seeded positional
       jitter (to break exact symmetries), then a final MMFF optimisation.
    2. ETKDGv3 embedding, seeded with ``random_seed``.
    3. Flat 2D depiction coordinates as a last resort.

    Hydrogens are added for embedding but only heavy atoms are written. The file
    is named ``<first 10 hex chars of md5(smiles)>.xyz`` and its comment (second)
    line is the SMILES exactly as given, which ``get_smiles`` reads back.

    Args:
        smiles: input SMILES string.
        output_dir: directory to write into (created if missing).
        random_seed: seed for RDKit embedding and for the jitter noise.
        jitter_amp: full width of the uniform noise added to every coordinate
            (noise lies in ``[-jitter_amp/2, jitter_amp/2)``).

    Returns:
        Path of the written file, or None if the SMILES cannot be parsed, no
        conformer can be generated, two atoms coincide (distance <= 1e-6), or
        the molecule has no heavy atoms.
    """
    os.makedirs(output_dir, exist_ok=True)

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    # Explicit hydrogens give the force fields a sensible geometry; they are dropped on output.
    mol = Chem.AddHs(mol)

    conformer_generated = False

    # Method 1: random-coordinate embedding + UFF, jitter, then MMFF.
    try:
        embed_result = AllChem.EmbedMolecule(mol, maxAttempts=2000, randomSeed=random_seed, useRandomCoords=True)
        if embed_result == 0:  # Success
            uff = AllChem.UFFGetMoleculeForceField(mol)
            if uff is not None:
                uff.Initialize()
                uff.Minimize(maxIts=1000)

                # Local RNG: yields the same values as seeding numpy's global RNG,
                # without clobbering the caller's global random state.
                rng = np.random.RandomState(random_seed)
                conf = mol.GetConformer()
                pos = conf.GetPositions()
                noise = (rng.rand(*pos.shape) - 0.5) * jitter_amp
                for i, p in enumerate(pos + noise):
                    conf.SetAtomPosition(i, p.tolist())

                AllChem.MMFFOptimizeMolecule(mol)
                conformer_generated = True
    except Exception as e:
        print(f"Method 1 failed: {e}")

    # Method 2: ETKDGv3, seeded so the fallback is reproducible too.
    if not conformer_generated:
        try:
            params = AllChem.ETKDGv3()
            params.randomSeed = random_seed
            if AllChem.EmbedMolecule(mol, params) == 0:
                conformer_generated = True
        except Exception as e:
            print(f"Method 2 failed: {e}")

    # Method 3: flat 2D coordinates. Compute2DCoords clears existing conformers
    # by default and adds its own, so no placeholder conformer is needed.
    if not conformer_generated:
        try:
            AllChem.Compute2DCoords(mol)
            conformer_generated = mol.GetNumConformers() > 0
        except Exception as e:
            print(f"Method 3 failed: {e}")

    if not conformer_generated or mol.GetNumConformers() == 0:
        return None

    pos = mol.GetConformer().GetPositions()

    # Reject geometries where two distinct atoms (hydrogens included) coincide.
    deltas = pos[:, None, :] - pos[None, :, :]
    dists = np.sqrt((deltas ** 2).sum(axis=-1))
    np.fill_diagonal(dists, np.inf)
    if (dists <= 1e-6).any():
        return None

    heavy_atoms = [
        (atom.GetSymbol(), pos[atom.GetIdx()])
        for atom in mol.GetAtoms()
        if atom.GetSymbol() != 'H'
    ]
    if not heavy_atoms:
        return None

    digest = hashlib.md5(smiles.encode()).hexdigest()[:10]
    path = os.path.join(output_dir, f"{digest}.xyz")

    # NOTE(review): ase.io.read presumably treats .xyz as extended XYZ, which
    # parses the comment line as key=value pairs; SMILES containing '=' may be
    # parsed oddly there. Confirm reads succeed for such molecules.
    with open(path, 'w') as f:
        f.write(f"{len(heavy_atoms)}\n")
        f.write(f"{smiles}\n")
        for sym, (x, y, z) in heavy_atoms:
            f.write(f"{sym} {x:.6f} {y:.6f} {z:.6f}\n")

    return path

In [None]:
def canonicalize_smiles(smiles):
    """Return RDKit's canonical SMILES for ``smiles``, or None if it cannot be handled.

    Returns None when RDKit fails to parse the string, or when RDKit raises an
    ordinary exception while parsing or writing. Interpreter-level exceptions
    such as KeyboardInterrupt are deliberately not swallowed.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None

In [None]:
def get_smiles(input_file):
    """Return the SMILES stored on the comment (second) line of an XYZ file.

    ``smiles_to_xyz_file`` writes the SMILES verbatim on that line, so this is the
    inverse mapping used to pair files with DataFrame rows. Surrounding
    whitespace is stripped; the string is not re-canonicalised.

    Raises:
        IndexError: if the file has fewer than two lines.
    """
    with open(input_file, 'r') as handle:
        all_lines = handle.read().splitlines()
    comment_line = all_lines[1]
    return comment_line.strip()

In [None]:
class XYZFolderDataset(AtomsDataset):
    """
    Dataset that mimics orb_models' AseSqliteDataset but reads one ``.xyz`` file
    per molecule from a folder, pairing files with rows of a pandas DataFrame.

    Files are matched to rows by the SMILES stored on each file's second
    (comment) line, read via ``get_smiles``, and compared with the DataFrame's
    ``'smiles'`` column. Rows without a matching file are dropped at construction
    time and a warning is printed.

    Items are ``(AtomGraphs, target_tensor, idx)`` triples, where ``idx`` is the
    row position in the filtered ``self.df``.
    """

    def __init__(
        self,
        df,                        # pandas DataFrame with a 'smiles' column and `target_col`
        xyz_dir: Union[str, Path],
        system_config: atomic_system.SystemConfig,
        target_col: str,
        target_config: Optional[property_definitions.PropertyConfig] = None,
        augmentations: Optional[List[Callable[[ase.Atoms], None]]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(name="xyz_folder", system_config=system_config, augmentations=augmentations)
        self.xyz_dir = Path(xyz_dir)
        self.target_col = target_col
        self.dtype = dtype or torch.get_default_dtype()
        self.target_config = target_config or property_definitions.PropertyConfig()
        self.constraints: List[Callable] = []

        # SMILES (read back from each file's comment line) -> file path.
        # If two files carry the same SMILES, the one globbed last wins.
        self._lookup: Dict[str, Path] = {
            get_smiles(str(path)): path for path in self.xyz_dir.glob("*.xyz")
        }

        frame = df.reset_index(drop=True)
        mask = frame['smiles'].isin(list(self._lookup))
        num_bad = int((~mask).sum())
        if num_bad > 0:
            print(f"Warning: dropping {num_bad} rows with no matching .xyz file")
        self.df = frame.loc[mask].reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.df)

    def _path_for_smiles(self, smiles: str) -> Path:
        """Return the .xyz path recorded for ``smiles``.

        Raises:
            KeyError: if no file in ``xyz_dir`` carries this SMILES.
        """
        xyz_path = self._lookup.get(smiles)
        if xyz_path is None:
            raise KeyError(f"Could not find .xyz for SMILES `{smiles}` in {self.xyz_dir}")
        return xyz_path

    def __getitem__(self, idx: int) -> Tuple[AtomGraphs, torch.Tensor, int]:
        """Build the graph for row ``idx`` and return ``(graph, target, idx)``.

        The target is read from ``self.target_col`` and converted with
        ``torch.tensor(..., dtype=self.dtype)``. This assumes the column holds
        scalars (TODO confirm for multi-label tasks). Augmentations are applied
        to the ASE Atoms before graph construction.
        """
        row = self.df.iloc[idx]
        smiles = row['smiles']
        target = row[self.target_col]
        xyz_path = self._path_for_smiles(smiles)

        atoms = ase.io.read(str(xyz_path))
        # Discard any info parsed from the file and keep only the configured targets.
        atoms.info = {}
        atoms.info.update(self.target_config.extract(row.to_dict(), self.name, "targets"))

        for aug in self.augmentations or []:
            aug(atoms)

        atoms.set_constraint(None)
        for c in self.constraints:
            c(atoms, {}, self.name)

        graph = atomic_system.ase_atoms_to_atom_graphs(
            atoms=atoms,
            system_config=self.system_config,
            edge_method="knn_scipy",
            wrap=True,
            system_id=idx,
            output_dtype=self.dtype,
            graph_construction_dtype=self.dtype,
        )

        return graph, torch.tensor(target, dtype=self.dtype), idx

    def get_atom(self, idx: int) -> ase.Atoms:
        """Return the raw ASE Atoms for this index."""
        row = self.df.iloc[idx]
        return ase.io.read(str(self._path_for_smiles(row['smiles'])))

    def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]:
        """Return both ASE Atoms and the row's metadata dict."""
        row = self.df.iloc[idx]
        atoms = ase.io.read(str(self._path_for_smiles(row['smiles'])))
        return atoms, row.to_dict()

In [None]:
def collate_graph_and_targets(batch):
    """Collate ``(AtomGraphs, target, idx)`` triples produced by ``XYZFolderDataset``.

    Returns a single batched ``AtomGraphs``, the targets stacked along a new
    leading dimension, and a tensor of the dataset indices.
    """
    graph_items, target_items, index_items = zip(*batch)
    return (
        base.batch_graphs(graph_items),
        torch.stack(target_items, dim=0),
        torch.tensor(index_items),
    )

In [None]:
def _make_xyz_loader(
    df,
    xyz_dir: str,
    system_config,
    target_col: str,
    batch_size: int,
    num_workers: int,
    target_config,
    augmentations,
    dtype,
    shuffle: bool,
):
    """Build an XYZFolderDataset and wrap it in a DataLoader with the shared collate fn."""
    dataset = XYZFolderDataset(
        df=df,
        xyz_dir=xyz_dir,
        system_config=system_config,
        target_col=target_col,
        target_config=target_config,
        augmentations=augmentations,
        dtype=dtype,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        worker_init_fn=utils.worker_init_fn,
        collate_fn=collate_graph_and_targets,
        # Pinned memory only helps host->GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
    )


def make_train_loader(
    df,
    xyz_dir: str,
    system_config,
    target_col: str,
    batch_size: int,
    num_workers: int,
    target_config=None,
    augmentations=None,
    dtype=None,
):
    """Shuffled DataLoader over the ``.xyz`` files in ``xyz_dir`` matched to ``df`` rows."""
    return _make_xyz_loader(
        df, xyz_dir, system_config, target_col, batch_size, num_workers,
        target_config, augmentations, dtype, shuffle=True,
    )


def make_val_loader(
    df,
    xyz_dir: str,
    system_config,
    target_col: str,
    batch_size: int,
    num_workers: int,
    target_config=None,
    augmentations=None,
    dtype=None,
):
    """Unshuffled (deterministic-order) DataLoader for validation/test evaluation."""
    return _make_xyz_loader(
        df, xyz_dir, system_config, target_col, batch_size, num_workers,
        target_config, augmentations, dtype, shuffle=False,
    )

In [None]:
class Model(nn.Module):
    """Pretrained Orb v3 (conservative, OMat) backbone with an MLP prediction head.

    ``forward`` takes the node embeddings from one intermediate backbone layer,
    mean-pools them per graph, and passes the pooled vector through the head.
    """

    def __init__(self,
                 reg_drop_rate=0.1,
                 reg_size=256,
                 num_labels=1,
                 hidden_size=256):
        """
        Args:
            reg_drop_rate: dropout probability applied before each Linear layer of the head.
            reg_size: width of the head's hidden layer.
            num_labels: number of outputs per graph.
            hidden_size: width of the backbone node embeddings fed into the head.
                It must match the width of the backbone's intermediate layers.
                The default of 256 is the original hard-coded value (TODO confirm
                for other backbones).
        """
        super().__init__()
        self.reg_drop_rate = reg_drop_rate
        self.num_targets = num_labels
        self.reg_size = reg_size

        # Backbone is built on CPU and moved with `.to(...)` by the caller.
        # (pretrained.orb_v3_direct_inf_mpa is an alternative backbone.)
        self.gnn = pretrained.orb_v3_conservative_inf_omat(
            device='cpu',
            precision="float32-highest",
        )

        self.hidden_size = hidden_size
        self.regressor = nn.Sequential(
            nn.Dropout(self.reg_drop_rate),
            nn.Linear(self.hidden_size, self.reg_size),
            nn.SiLU(),
            nn.Dropout(self.reg_drop_rate),
            nn.Linear(self.reg_size, self.num_targets)
        )

    def forward(self, batched_graph, layer_idx):
        """Return ``(graph_emb, output)`` for a batched ``AtomGraphs``.

        ``graph_emb`` is the per-graph mean of the node embeddings taken from
        ``self.gnn(...)["intermediate_layers"][layer_idx]``. ``output`` is the
        head applied to it. Presumably its shape is (num_graphs, num_labels);
        confirm against the backbone output.
        """
        gnn_out = self.gnn(batched_graph)
        node_emb = gnn_out["intermediate_layers"][layer_idx]
        batch_index = batched_graph._get_per_node_graph_indices().long()
        graph_emb = global_mean_pool(node_emb, batch_index)
        return graph_emb, self.regressor(graph_emb)

In [None]:
# Device used for every model and batch in the fine-tuning cell below.
device = 'cuda'
# Template model: the fine-tuning cell reads `model.gnn.system_config` from it to build loaders.
model = Model()

In [None]:
import copy
import math
from sklearn.preprocessing import StandardScaler

# Anomaly detection slows every backward pass considerably (it was left on from
# NaN debugging). Set to True only when hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

base_dir = '../input_data/tdcommons/admet_group'
results_dir = 'tmp'                 # per-task result CSVs are written here
SEED = 42                           # reseeded before every (layer, lr) run so runs are comparable
NUM_LAYERS = 5                      # intermediate backbone layers probed: 0..NUM_LAYERS-1
LEARNING_RATES = [1e-5, 2e-5, 5e-5, 1e-4, 2e-4]
EPOCHS = 50
WARMUP_FRACTION = 0.05
ACCUMULATION_STEPS = 1
REGRESSION_METRICS = ('mae', 'spearman', 'pearson')
TASK_PREFIX = 'tdcommons/'


def _resolve_metric(task):
    """Return the evaluation metric for ``task`` based on the tdc_utils task lists."""
    key = TASK_PREFIX + task
    if key in tdc_utils.tdc_mae_tasks:
        return 'mae'
    elif key in tdc_utils.tdc_spearman_task:
        return 'spearman'
    elif key in tdc_utils.polaris_pearson_tasks:
        return 'pearson'
    elif key in tdc_utils.tdc_auroc_tasks:
        return 'auc'
    elif key in tdc_utils.tdc_aucpr_tasks:
        return 'aucpr'
    elif key in tdc_utils.tdc_aucpr2_tasks:
        return 'aucpr'
    elif key in tdc_utils.polaris_aucpr_tasks:
        return 'aucpr'
    raise ValueError(f"Task {task} not found in any known task list.")


def _load_scaffold_split(task):
    """Load the TDC scaffold split for ``task`` as (train, val, test) frames with 'smiles'/'target' columns."""
    try:
        data = Tox(name=task)
    except Exception:
        # Presumably not a Tox dataset; fall back to the ADME group.
        data = ADME(name=task)
    splits = data.get_split(method='scaffold')

    def tidy(frame):
        return frame.rename({'Drug': 'smiles', 'Y': 'target'}, axis=1).drop('Drug_ID', axis=1)

    return tidy(splits['train']), tidy(splits['valid']), tidy(splits['test'])


def _score(metric, targs, preds):
    """Compute ``metric`` on flat targets/predictions; higher is better except for 'mae'."""
    targs = np.asarray(targs)
    preds = np.asarray(preds)
    if metric == 'mae':
        return np.mean(np.abs(preds - targs))
    if metric == 'spearman':
        return spearmanr(targs, preds)[0]
    if metric == 'pearson':
        return pearsonr(targs, preds)[0]
    if metric == 'auc':
        return roc_auc_score(targs, preds)
    if metric == 'aucpr':
        return average_precision_score(targs, preds)
    raise ValueError(f"Unknown metric: {metric}")


def _is_better(metric, score, best):
    """True if ``score`` improves on ``best``; NaN scores never count as improvements."""
    if np.isnan(score):
        return False
    return score < best if metric == 'mae' else score > best


def _predict(model, loader, layer_idx):
    """Run ``model`` over ``loader`` in eval mode; return flat (preds, targets) numpy arrays."""
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for mol_graphs, targets, _ in loader:
            _, out = model(mol_graphs.to(device), layer_idx)
            preds.extend(out.reshape(-1).cpu().tolist())
            targs.extend(targets.reshape(-1).tolist())
    return np.asarray(preds), np.asarray(targs)


def _train_for_lr(lr, layer_idx, metric, train_loader, val_loader):
    """Fine-tune a fresh Model at learning rate ``lr`` on backbone layer ``layer_idx``.

    Returns ``(best_val, best_state, best_epoch)`` for the epoch with the best
    validation score. ``best_state`` is a CPU copy of the state dict, or None if
    no epoch produced a valid (non-NaN) score.
    """
    torch.manual_seed(SEED)
    model = Model(reg_size=256).to(device)
    optimizer = AdamW(model.parameters(), lr=lr)
    # One optimizer/scheduler step per ACCUMULATION_STEPS batches, plus one for any remainder.
    num_training_steps = math.ceil(len(train_loader) / ACCUMULATION_STEPS) * EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    loss_fn = MSELoss() if metric in REGRESSION_METRICS else BCEWithLogitsLoss()

    best_val = float('inf') if metric == 'mae' else -float('inf')
    best_state, best_epoch = None, -1
    for epoch in range(EPOCHS):
        model.train()
        optimizer.zero_grad()
        step = -1  # keeps the remainder check below valid for an empty loader
        for step, (mol_graphs, targets, _) in enumerate(train_loader):
            mol_graphs = mol_graphs.to(device)
            targets = targets.to(device)
            _, preds = model(mol_graphs, layer_idx)
            # reshape(-1) rather than squeeze(): squeeze() turns a size-1 final batch
            # into a 0-d tensor, which BCEWithLogitsLoss rejects as a size mismatch.
            loss = loss_fn(preds.reshape(-1), targets.reshape(-1)) / ACCUMULATION_STEPS
            loss.backward()
            if (step + 1) % ACCUMULATION_STEPS == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        # Flush gradients left over when the batch count is not a multiple of ACCUMULATION_STEPS.
        if (step + 1) % ACCUMULATION_STEPS != 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        val_preds, val_targs = _predict(model, val_loader, layer_idx)
        val_score = _score(metric, val_targs, val_preds)
        if _is_better(metric, val_score, best_val):
            best_val = val_score
            # CPU copy so holding the best state across the LR sweep does not pin GPU memory.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
    return best_val, best_state, best_epoch


os.makedirs(results_dir, exist_ok=True)
# Taken from the template Model built in an earlier cell.
system_config = model.gnn.system_config
loader_common = dict(system_config=system_config, target_col='target',
                     target_config=None, dtype=torch.float32)

dfs = []
for task in sorted(os.listdir(base_dir)):
    if task.startswith('.'):
        continue
    print(task)

    metric = _resolve_metric(task)
    train_df, val_df, test_df = _load_scaffold_split(task)

    # Standardise regression targets with train statistics only; test predictions
    # are mapped back to original units before scoring.
    scaler = None
    if metric in REGRESSION_METRICS:
        scaler = StandardScaler().fit(train_df[['target']].values)
        for frame in (train_df, val_df, test_df):
            frame['target'] = scaler.transform(frame[['target']].values).ravel()

    # Convert SMILES to heavy-atom XYZ files once per split. A folder that already
    # holds at least one .xyz file is assumed complete and skipped.
    for split_df, split_name in ((train_df, 'train'), (val_df, 'val'), (test_df, 'test')):
        out_dir = f'xyz_files/{task}/{split_name}/graphs'
        if os.path.isdir(out_dir) and any(f.endswith('.xyz') for f in os.listdir(out_dir)):
            print('Already exists')
            continue
        os.makedirs(out_dir, exist_ok=True)
        for smi in split_df['smiles'].tolist():
            smiles_to_xyz_file(smi, out_dir)

    # NOTE(review): XYZFolderDataset drops rows whose SMILES failed conversion, so
    # val/test metrics may be computed on a subset of the official split.
    train_dataloader = make_train_loader(
        df=train_df, xyz_dir=f'xyz_files/{task}/train/graphs',
        batch_size=64, num_workers=4, augmentations=[rotate_randomly], **loader_common,
    )
    val_dataloader = make_val_loader(
        df=val_df, xyz_dir=f'xyz_files/{task}/val/graphs',
        batch_size=256, num_workers=1, augmentations=None, **loader_common,
    )
    test_dataloader = make_val_loader(
        df=test_df, xyz_dir=f'xyz_files/{task}/test/graphs',
        batch_size=256, num_workers=1, augmentations=None, **loader_common,
    )

    layer_indices = list(range(NUM_LAYERS))
    test_metrics = []
    for layer_idx in layer_indices:
        print(f'Layer: {layer_idx}')

        # Model selection over the whole LR sweep: keep the single best
        # (lr, epoch) checkpoint by validation score.
        best_val = float('inf') if metric == 'mae' else -float('inf')
        best_state, best_lr, best_epoch = None, None, -1
        for lr in LEARNING_RATES:
            print(f'LR={lr}')
            lr_val, lr_state, lr_epoch = _train_for_lr(lr, layer_idx, metric, train_dataloader, val_dataloader)
            if lr_state is not None and _is_better(metric, lr_val, best_val):
                best_val, best_state, best_lr, best_epoch = lr_val, lr_state, lr, lr_epoch
            gc.collect()
            torch.cuda.empty_cache()

        if best_state is None:
            print(f"Layer {layer_idx}: no valid validation score for any LR; recording NaN")
            test_metrics.append(float('nan'))
            continue
        print(f"  best LR={best_lr}, epoch={best_epoch}, val_{metric}={best_val:.4f}")

        # Evaluate the selected checkpoint on the test set.
        model = Model(reg_size=256).to(device)
        model.load_state_dict(best_state)
        test_preds, test_targs = _predict(model, test_dataloader, layer_idx)
        if scaler is not None:
            test_preds = scaler.inverse_transform(test_preds.reshape(-1, 1)).ravel()
            test_targs = scaler.inverse_transform(test_targs.reshape(-1, 1)).ravel()

        test_score = _score(metric, test_targs, test_preds)
        test_metrics.append(test_score)
        print(f"Layer {layer_idx}: {test_score}")

    # One column per task, one row per probed layer.
    results_df = pd.DataFrame({task: test_metrics}, index=layer_indices)
    dfs.append(results_df)
    results_df.to_csv(os.path.join(results_dir, f"orb_conserv_{task}_results.csv"), index=False)

In [None]:
# Gather the per-task result CSVs written by the fine-tuning cell (it writes to tmp/).
# NOTE(review): change results_read_dir if the CSVs were moved elsewhere by hand (e.g. tmp_256/).
results_read_dir = 'tmp'
dfs = []
# Sorted so the concatenated column order is deterministic across runs/filesystems.
for fname in sorted(os.listdir(results_read_dir)):
    if 'conserv' in fname and fname.endswith('.csv'):
        dfs.append(pd.read_csv(os.path.join(results_read_dir, fname)))

In [None]:
# Side-by-side table: one column per task, one row per probed backbone layer.
combined_results = pd.concat(dfs, axis=1)
combined_results.to_csv('./results_orb_conserv_finetune.csv', index=False)

In [None]:
# nan_mask = torch.isnan(preds)
# if nan_mask.any():
#     bad_positions = nan_mask.nonzero(as_tuple=True)[0]
#     print("⚠️ NaN in preds at batch indices:", bad_positions.tolist())
    # batch.idx gives you the original dataset index for each sample
    # bad_dataset_idxs = batch.idx[bad_positions].tolist()
    # print("… which correspond to dataset indices:", bad_dataset_idxs)
    # for di in bad_dataset_idxs:
    #     print("  file:", train_dataset.paths[di])
    # raise RuntimeError("Stopping: found NaN in model output")

In [None]:
# for step, (mol_graphs, targets, idxs) in enumerate(train_dataloader):
#     mol_graphs = mol_graphs.to('cuda')
#     idxs = idxs.to('cuda')
#     with torch.no_grad():
#         _, preds = model(mol_graphs, layer_idx)
#     nan_mask = torch.isnan(preds)
#     if nan_mask.any():
#         bad_batch_pos = nan_mask.nonzero(as_tuple=True)[0].unique()
#         bad_dataset_idx = idxs[bad_batch_pos].tolist()
#         print("⚠️ NaN preds at batch positions", bad_batch_pos.tolist(),
#               "→ dataset rows", bad_dataset_idx)
#         # now inspect each bad graph
#         for di in bad_dataset_idx:
#             g, t, _ = train_dataloader.dataset[di]
#             print(f"\n--- Inspecting dataset row {di}:")
#             print(train_dataloader.dataset.df.iloc[di])
#             # check graph-level stats
#             print(f"  n_nodes = {g.n_node}, n_edges = {g.n_edge}")
#             # assume node features live in g.node_features
#             nf = g.node_features
#             ef = g.edge_features   # adjust to your field names
#             print("  node_features has NaN?", torch.isnan(nf).any().item())
#             print("  edge_features has NaN?", torch.isnan(ef).any().item())
#             # if you store positions on the graph:
#             pos = g.node_positions
#             print("  positions has NaN?", torch.isnan(pos).any().item())
#         break  # stop after first failure to keep logs manageable

In [None]:
# import torch

# pos = g.node_features["positions"]  # shape: (n_nodes, 3)
# # Compute full pairwise distance matrix
# # (PyTorch ≥1.1 has torch.cdist; if you’re on older PyTorch, you can do it with broadcasting.)
# dists = torch.cdist(pos, pos)  # shape: (n_nodes, n_nodes)

# # Zero out the diagonal so we only look at *distinct* atom pairs
# n = pos.size(0)
# eye = torch.eye(n, dtype=torch.bool, device=dists.device)
# dists[eye] = float("inf")

# # Find all pairs closer than eps (e.g. exactly zero or nearly zero)
# eps = 1e-6
# overlap_mask = dists <= eps
# overlaps = overlap_mask.nonzero(as_tuple=False)  # each row is [i, j]

# if overlaps.numel() == 0:
#     print("✔️ No overlapping atoms detected.")
# else:
#     print("⚠️ Overlapping atom pairs (i, j):")
#     for i, j in overlaps.tolist():
#         print(f"   Atom {i} and Atom {j} both at {pos[i].tolist()}")


In [None]:
# # assume g is the single-graph AtomGraphs you pulled out for idx 6128

# print(f"\nGraph has {g.n_node.item()} nodes, {g.n_edge.item()} edges\n")

# # 1) Check every node-feature tensor
# for name, feat in g.node_features.items():
#     has_nan = torch.isnan(feat).any().item()
#     print(f"node_feature[{name}]: shape={tuple(feat.shape)}, NaN? {has_nan}")
#     if has_nan:
#         print("   →", name, "min/max:", feat.nanmin().item(), feat.nanmax().item())

# # 2) Check every edge-feature tensor
# for name, feat in g.edge_features.items():
#     has_nan = torch.isnan(feat).any().item()
#     print(f"edge_feature[{name}]: shape={tuple(feat.shape)}, NaN? {has_nan}")
#     if has_nan:
#         print("   →", name, "min/max:", feat.nanmin().item(), feat.nanmax().item())

# # 3) If you store positions explicitly on g:
# if hasattr(g, "node_positions"):
#     pos = g.node_positions
#     print("node_positions:", pos.shape, "NaN?", torch.isnan(pos).any().item())
