In [6]:
import os
import pickle
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx

import sys
from mmfdl.util.data_gen_modify import make_variable_one
from mmfdl.util.utils import formDataset_Single

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

In [7]:
def atom_features(atom):
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
        ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
         'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H','Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
         'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
        one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
        one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) +
        [atom.GetIsAromatic()])

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception(f"input {x} not in allowable set{allowable_set}:")
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

def smile_to_graph(smile):
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None, None, None
    
    c_size = mol.GetNumAtoms()
    
    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        features.append(feature / sum(feature))

    edges = []
    for bond in mol.GetBonds():
        edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
    g = nx.Graph(edges).to_directed()
    edge_index = []
    for e1, e2 in g.edges:
        edge_index.append([e1, e2])
        
    return c_size, features, edge_index

def process_csv_to_pt(csv_path, vocab_path, results_dir, dataset_name, 
                      input_col='SMILES', target_col='Ssel', max_smiles_len=44, ecfp_bits=2048):
    """
    Generate .pt from csv
    
    Parameters:
    -----------
    csv_path : str
    vocab_path : str
    results_dir : str
    dataset_name : str
    input_col : str
    target_col : str
    max_smiles_len : int
    ecfp_bits : int
    """
    print(f'\nProcessing {csv_path}...')
    
    df = pd.read_csv(csv_path)
    if input_col not in df.columns:
        raise ValueError(f"Input column '{input_col}' not found in CSV. Available columns: {df.columns.tolist()}")
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in CSV. Available columns: {df.columns.tolist()}")
    
    smiles_list = df[input_col].dropna().tolist()
    labels = df[target_col].dropna().tolist()
    
    min_len = min(len(smiles_list), len(labels))
    smiles_list = smiles_list[:min_len]
    labels = labels[:min_len]
    
    print(f'  Found {len(smiles_list)} samples')
    
    with open(vocab_path, 'rb') as f:
        smilesVoc = pickle.load(f)
    
    encoded_smi_list = []
    ecfp_list = []
    labels_list = []
    smile_graph_dict = {}
    
    valid_count = 0
    for idx, (smi, label) in enumerate(zip(smiles_list, labels)):
        if pd.isna(smi) or pd.isna(label):
            continue
        
        try:
            # SMILES encoding
            encoded_smi = make_variable_one(smi, smilesVoc, max_smiles_len)
            
            # Generate ECFP
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue
            if mol.HasSubstructMatch(Chem.MolFromSmarts("[H]")):
                mol = Chem.RemoveHs(mol)
            ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=ecfp_bits)
            ecfp_array = np.array(ecfp, dtype=np.float32)
            
            # Generate Graph
            c_size, features, edge_index = smile_to_graph(smi)
            if edge_index == [] or features is None:
                continue
            
            encoded_smi_list.append(encoded_smi)
            ecfp_list.append(ecfp_array.tolist())
            labels_list.append(float(label))
            smile_graph_dict[valid_count] = (c_size, features, edge_index)
            valid_count += 1
            
        except Exception as e:
            print(f'  Error processing SMILES {idx}: {smi}, error: {e}')
            continue
        
    same = len(smiles_list) == valid_count
    print(same)
    
    dataset = formDataset_Single(
        root=results_dir,
        dataset=dataset_name,
        encodedSmi=encoded_smi_list,
        ecfp=ecfp_list,
        y=labels_list,
        smile_graph=smile_graph_dict
    )
    
    return dataset

In [8]:
def list_dirs(root_dir):
    """Return sorted directory names under root_dir."""
    if not os.path.exists(root_dir):
        return []
    return sorted([
        d for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    ])

In [9]:
TDC_PROCESSED_ROOT = "./../tdc/data/processed"
OUT_ROOT = "./data"

start_fold_num = 1
end_fold_num = 5

INPUT_COL = "smiles"     # column name in processed fold CSVs
ECPF_BITS = 2048

In [10]:
dataset_names = list_dirs(TDC_PROCESSED_ROOT)

for dataset_name in dataset_names:
    dataset_dir = os.path.join(TDC_PROCESSED_ROOT, dataset_name)
    target_cols = list_dirs(dataset_dir)

    print("\n" + "#" * 80)
    print(f"Dataset: {dataset_name}")
    print("#" * 80)

    for target_col in target_cols:
        print("\n" + "-" * 70)
        print(f"Target: {target_col}")
        print("-" * 70)

        # ------------------------------------------------------------
        # Load vocabulary (target-specific)
        # ------------------------------------------------------------
        vocab_path = os.path.join(
            OUT_ROOT, dataset_name, target_col, "smiles_char_dict.pkl"
        )

        if not os.path.exists(vocab_path):
            print(f"[SKIP] Vocabulary not found: {vocab_path}")
            continue

        with open(vocab_path, "rb") as f:
            smilesVoc = pickle.load(f)

        print(f"Loaded vocabulary size: {len(smilesVoc)}")

        # ------------------------------------------------------------
        # Fold loop
        # ------------------------------------------------------------
        for fold_num in range(start_fold_num, end_fold_num + 1):
            input_dir = os.path.join(
                TDC_PROCESSED_ROOT, dataset_name, target_col, f"fold{fold_num}"
            )
            results_dir = os.path.join(
                OUT_ROOT, dataset_name, target_col, f"fold{fold_num}"
            )
            os.makedirs(results_dir, exist_ok=True)

            # Input CSV paths
            train_csv = os.path.join(input_dir, "train.csv")
            val_csv = os.path.join(input_dir, "val.csv")
            test_csv = os.path.join(input_dir, "test.csv")

            # Basic existence check
            missing = [p for p in [train_csv, val_csv, test_csv] if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] fold{fold_num} missing files: {missing}")
                continue

            print("\n" + "=" * 60)
            print(f"Processing fold{fold_num}")
            print("=" * 60)

            # Train
            _ = process_csv_to_pt(
                csv_path=train_csv,
                vocab_path=vocab_path,
                results_dir=results_dir,
                dataset_name=f"train",
                input_col=INPUT_COL,
                target_col=target_col,
                ecfp_bits=ECPF_BITS,
            )

            # Validation
            _ = process_csv_to_pt(
                csv_path=val_csv,
                vocab_path=vocab_path,
                results_dir=results_dir,
                dataset_name=f"val",
                input_col=INPUT_COL,
                target_col=target_col,
                ecfp_bits=ECPF_BITS,
            )

            # Test
            _ = process_csv_to_pt(
                csv_path=test_csv,
                vocab_path=vocab_path,
                results_dir=results_dir,
                dataset_name=f"test",
                input_col=INPUT_COL,
                target_col=target_col,
                ecfp_bits=ECPF_BITS,
            )

            print(f"fold{fold_num} completed → {results_dir}")


################################################################################
Dataset: AMES
################################################################################

----------------------------------------------------------------------
Target: AMES
----------------------------------------------------------------------
Loaded vocabulary size: 54

Processing fold1

Processing ./../tdc/data/processed/AMES/AMES/fold1/train.csv...
  Found 5223 samples
True

Processing ./../tdc/data/processed/AMES/AMES/fold1/val.csv...
  Found 581 samples
True

Processing ./../tdc/data/processed/AMES/AMES/fold1/test.csv...
  Found 1451 samples
True
fold1 completed → ./data/AMES/AMES/fold1

Processing fold2

Processing ./../tdc/data/processed/AMES/AMES/fold2/train.csv...
  Found 5223 samples
True

Processing ./../tdc/data/processed/AMES/AMES/fold2/val.csv...
  Found 581 samples
True

Processing ./../tdc/data/processed/AMES/AMES/fold2/test.csv...
  Found 1451 samples
True
fold2 completed → ./data/