In [1]:
import os
import pickle

from typing import Optional, Tuple, List

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx
from mmfdl.util.data_gen_modify import make_variable_one
from mmfdl.util.utils import formDataset_Single

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

In [2]:
def atom_features(atom):
    """Build a flat boolean feature vector for a single RDKit atom.

    The vector is the concatenation, in this order, of:
      * a one-hot element symbol over a fixed 44-entry list (symbols not in the
        list map to the trailing 'Unknown' slot),
      * a one-hot atom degree over 0..10 (a degree outside that range raises),
      * a one-hot total hydrogen count over 0..10 (out-of-range values map to 10),
      * a one-hot implicit valence over 0..10 (out-of-range values map to 10),
      * a single aromaticity flag.
    """
    symbols = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
        'Cr', 'Pt', 'Hg', 'Pb', 'Unknown',
    ]
    zero_to_ten = list(range(11))

    symbol_part = one_of_k_encoding_unk(atom.GetSymbol(), symbols)
    degree_part = one_of_k_encoding(atom.GetDegree(), zero_to_ten)
    hydrogen_part = one_of_k_encoding_unk(atom.GetTotalNumHs(), zero_to_ten)
    valence_part = one_of_k_encoding_unk(atom.GetImplicitValence(), zero_to_ten)
    aromatic_part = [atom.GetIsAromatic()]

    return np.array(symbol_part + degree_part + hydrogen_part + valence_part + aromatic_part)

def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Returns a list of booleans, one per entry of ``allowable_set``, that is
    True exactly where the entry equals ``x``. Raises ``Exception`` when ``x``
    is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception(f"input {x} not in allowable set{allowable_set}:")

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``; values outside ``allowable_set`` map to its last element.

    The last entry of ``allowable_set`` acts as the catch-all "unknown" bucket.
    """
    bucket = x if x in allowable_set else allowable_set[-1]
    return [bucket == candidate for candidate in allowable_set]

def smile_to_graph(smile):
    """Convert a SMILES string into ``(num_atoms, atom_features, edge_index)``.

    Returns ``(None, None, None)`` when RDKit cannot parse ``smile``.
    Each per-atom vector from ``atom_features`` is divided by its own sum.
    The bond list is turned into an undirected networkx graph and then into a
    directed one, so ``edge_index`` holds ``[src, dst]`` pairs in both
    directions. It is an empty list for a molecule without bonds.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None, None, None

    atom_count = mol.GetNumAtoms()

    normalized = []
    for atom in mol.GetAtoms():
        raw = atom_features(atom)
        normalized.append(raw / sum(raw))

    bond_pairs = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()]
    directed_graph = nx.Graph(bond_pairs).to_directed()
    pair_list = [[src, dst] for src, dst in directed_graph.edges]

    return atom_count, normalized, pair_list

def infer_max_smiles_len(
    smiles_list: List[str],
    max_cap: int = 256
) -> int:
    """
    Return the length of the longest non-empty SMILES string, capped at ``max_cap``.

    Entries that are not ``str`` or are empty are ignored. If nothing remains,
    0 is returned (the cap is not applied in that case).

    Args:
        smiles_list (List[str]): SMILES strings to scan.
        max_cap (int): Upper bound on the returned length, guarding against outliers.

    Returns:
        int: Capped maximum length, or 0 when there are no usable strings.
    """
    longest = max(
        (len(smi) for smi in smiles_list if isinstance(smi, str) and smi),
        default=0,
    )
    if not longest:
        return 0
    return int(min(longest, max_cap))


def process_csv_to_pt(
    csv_path: str,
    vocab_path: str,
    results_dir: str,
    dataset_name: str,
    input_col: str = "SMILES",
    target_col: str = "Ssel",
    max_smiles_len: Optional[int] = None,
    ecfp_bits: int = 2048,
    max_cap: int = 256
):
    """
    Generate a .pt dataset from a CSV of SMILES strings and targets.

    Every usable row is featurized three ways and handed to ``formDataset_Single``:
      * a vocabulary-encoded SMILES sequence (``make_variable_one``),
      * a Morgan fingerprint bit vector (radius 2, ``ecfp_bits`` bits),
      * a molecular graph from ``smile_to_graph``.
    A row is skipped when any of the following holds:
      * its SMILES or target is missing,
      * RDKit cannot parse the SMILES,
      * its graph has no edges,
      * featurization raises (the error is printed).

    Args:
        csv_path (str): Path to csv file.
        vocab_path (str): Path to smiles vocab pkl.
        results_dir (str): Passed to ``formDataset_Single`` as ``root``.
        dataset_name (str): Dataset name for pt.
        input_col (str): SMILES column name.
        target_col (str): Target column name.
        max_smiles_len (Optional[int]): If None, infer from dataset (capped).
        ecfp_bits (int): ECFP bit size.
        max_cap (int): Upper cap for inferred max length.

    Returns:
        Dataset: formDataset_Single instance.

    Raises:
        ValueError: If ``input_col`` or ``target_col`` is missing, or if
            ``max_smiles_len`` must be inferred but no non-empty SMILES exist.
    """
    print(f"\nProcessing {csv_path}...")

    df = pd.read_csv(csv_path)

    if input_col not in df.columns:
        raise ValueError(f"Input column '{input_col}' not found. Available: {df.columns.tolist()}")
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found. Available: {df.columns.tolist()}")

    # Drop rows where EITHER value is missing in a single pass so that every
    # SMILES keeps its own label. The previous approach was to dropna each
    # column separately and truncate to the shorter list. That paired SMILES
    # with the wrong targets whenever the NaNs sat in different rows.
    paired = df[[input_col, target_col]].dropna()
    smiles_list = paired[input_col].astype(str).tolist()
    labels = paired[target_col].tolist()

    if max_smiles_len is None:
        max_smiles_len = infer_max_smiles_len(smiles_list, max_cap=max_cap)
        if max_smiles_len <= 0:
            raise ValueError("Failed to infer max_smiles_len (no valid SMILES).")
        print(f"[INFO] Auto max_smiles_len inferred: {max_smiles_len} (cap={max_cap})")
    else:
        print(f"[INFO] Using provided max_smiles_len: {max_smiles_len}")

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted vocab files.
    with open(vocab_path, "rb") as f:
        smilesVoc = pickle.load(f)

    # Compile the explicit-hydrogen SMARTS pattern once instead of once per molecule.
    explicit_h_pattern = Chem.MolFromSmarts("[H]")

    encoded_smi_list = []
    ecfp_list = []
    labels_list = []
    smile_graph_dict = {}

    valid_count = 0
    for idx, (smi, label) in enumerate(zip(smiles_list, labels)):
        try:
            encoded_smi = make_variable_one(smi, smilesVoc, max_smiles_len)

            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue
            if mol.HasSubstructMatch(explicit_h_pattern):
                mol = Chem.RemoveHs(mol)

            ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=ecfp_bits)
            ecfp_array = np.array(ecfp, dtype=np.float32)

            # The graph is built from the raw SMILES string, not from the
            # H-stripped mol above. Molecules with no bonds (empty edge list)
            # are skipped.
            c_size, features, edge_index = smile_to_graph(smi)
            if features is None or not edge_index:
                continue

            encoded_smi_list.append(encoded_smi)
            ecfp_list.append(ecfp_array.tolist())
            labels_list.append(float(label))
            # Graph dict keys are contiguous indices of accepted rows, matching list order.
            smile_graph_dict[valid_count] = (c_size, features, edge_index)
            valid_count += 1

        except Exception as e:
            print(f"  Error processing SMILES {idx}: {smi}, error: {e}")
            continue

    print(f"[INFO] valid_count={valid_count} / total={len(smiles_list)}")

    dataset = formDataset_Single(
        root=results_dir,
        dataset=dataset_name,
        encodedSmi=encoded_smi_list,
        ecfp=ecfp_list,
        y=labels_list,
        smile_graph=smile_graph_dict,
    )

    return dataset

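### Load Vocabulary (legacy per-split pipeline — the two cells below are commented out; the active pipeline with a single global max SMILES length follows)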

In [3]:
# dataset_name = 'selectivity'
# task_name = 'Ki'
# vocab_path = f'./data/{dataset_name}/{task_name}/smiles_char_dict.pkl'

# with open(vocab_path, 'rb') as f:
#     smilesVoc = pickle.load(f)

In [4]:
# start_fold_num = 1
# end_fold_num = 5

# for fold_num in range(start_fold_num, end_fold_num + 1):
#     input_dir = f'/home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold{fold_num}'
#     results_dir = f'./data/{dataset_name}/{task_name}/fold{fold_num}'
#     os.makedirs(results_dir, exist_ok=True)

#     # input file path
#     train_csv = f'{input_dir}/{dataset_name}_train.csv'
#     val_csv = f'{input_dir}/{dataset_name}_val.csv'
#     test_csv = f'{input_dir}/{dataset_name}_test.csv'

#     # Results file path
#     train_pt = f'{results_dir}/{dataset_name}_train.pt'
#     val_pt = f'{results_dir}/{dataset_name}_val.pt'
#     test_pt = f'{results_dir}/{dataset_name}_test.pt'

#     # Generate Train, Validation, Test
#     print('\n' + '=' * 60)
#     print(f'Processing Fold {fold_num}')
#     print('=' * 60)

#     train_dataset = process_csv_to_pt(
#         csv_path=train_csv,
#         vocab_path=vocab_path,
#         results_dir=results_dir,
#         dataset_name=f'{dataset_name}_train',
#         input_col='SMILES',
#         target_col='Ssel',
#         ecfp_bits=2048
#     )

#     val_dataset = process_csv_to_pt(
#         csv_path=val_csv,
#         vocab_path=vocab_path,
#         results_dir=results_dir,
#         dataset_name=f'{dataset_name}_val',
#         input_col='SMILES',
#         target_col='Ssel',
#         ecfp_bits=2048
#     )

#     test_dataset = process_csv_to_pt(
#         csv_path=test_csv,
#         vocab_path=vocab_path,
#         results_dir=results_dir,
#         dataset_name=f'{dataset_name}_test',
#         input_col='SMILES',
#         target_col='Ssel',
#         ecfp_bits=2048
#     )

#     print(f'\nFold {fold_num} completed!')
#     print(f'Results saved to: {results_dir}')

In [5]:
def infer_max_smiles_len_from_original_csv(
    original_csv_path: str,
    smiles_col: str = "SMILES",
    cap: int = 256,
    percentile: Optional[float] = None,
) -> int:
    """
    Infer a single max_smiles_len from the original (pre-split) dataset.

    Args:
        original_csv_path (str): Path to the original dataset CSV (before fold split).
        smiles_col (str): Column name for SMILES.
        cap (int): Upper cap to avoid extremely long outliers.
        percentile (Optional[float]): If provided (e.g., 95.0), use that percentile
            of SMILES lengths instead of the absolute max. A fractional percentile
            is rounded UP. If None, use the absolute max. The cap is applied in
            both cases.

    Returns:
        int: max_smiles_len to be used consistently across all folds/splits.

    Raises:
        ValueError: If ``smiles_col`` is missing or no non-empty SMILES exist.
            ``np.percentile`` also raises ValueError for a percentile outside [0, 100].
    """
    df = pd.read_csv(original_csv_path)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' not found in {original_csv_path}. Available: {df.columns.tolist()}")

    lengths = [len(s) for s in df[smiles_col].dropna().astype(str) if len(s) > 0]
    if not lengths:
        raise ValueError("No valid SMILES found to infer max_smiles_len.")

    if percentile is None:
        inferred = int(max(lengths))
    else:
        # np.percentile interpolates and usually yields a fractional length.
        # Round up instead of truncating, so the chosen max length is never
        # below the computed percentile. Truncating with int() could push
        # coverage under the requested percentile (e.g. 9.55 -> 9).
        inferred = int(np.ceil(np.percentile(lengths, percentile)))

    return int(min(inferred, cap))


def build_fold_pt_with_global_maxlen(
    fold_csv_dir_root: str,
    results_dir_root: str,
    dataset_name: str,
    task_name: str,
    vocab_path: str,
    global_max_smiles_len: int,
    input_col: str = "SMILES",
    target_col: str = "Ssel",
    ecfp_bits: int = 2048,
    start_fold: int = 1,
    end_fold: int = 5,
) -> None:
    """
    Build train/val/test .pt datasets for folds ``start_fold`` through ``end_fold`` (inclusive).

    Fold ``k`` is processed as follows:
      * ``{dataset_name}_{split}.csv`` is read from ``fold_csv_dir_root/fold{k}``
        for each split in train, val and test.
      * ``results_dir_root/fold{k}`` is created if missing and used as output root.
      * ``process_csv_to_pt`` is called once per split.
    Every split of every fold receives the same ``global_max_smiles_len``.

    Args:
        fold_csv_dir_root (str): Root dir containing fold{n} directories with CSV files.
        results_dir_root (str): Root output dir where fold{n} directories will be created.
        dataset_name (str): Dataset name prefix.
        task_name (str): Task name. Not used in the body; kept for interface compatibility.
        vocab_path (str): Path to smiles vocab pkl.
        global_max_smiles_len (int): Global max length inferred from original dataset.
        input_col (str): SMILES column name.
        target_col (str): Target column name.
        ecfp_bits (int): ECFP bit size.
        start_fold (int): First fold index.
        end_fold (int): Last fold index (inclusive).

    Returns:
        None
    """
    # Arguments identical for every split of every fold.
    shared_kwargs = dict(
        vocab_path=vocab_path,
        input_col=input_col,
        target_col=target_col,
        max_smiles_len=global_max_smiles_len,
        ecfp_bits=ecfp_bits,
    )

    for fold_idx in range(start_fold, end_fold + 1):
        fold_tag = f"fold{fold_idx}"
        source_dir = os.path.join(fold_csv_dir_root, fold_tag)
        target_dir = os.path.join(results_dir_root, fold_tag)
        os.makedirs(target_dir, exist_ok=True)

        banner = "=" * 60
        print("\n" + banner)
        print(f"Processing Fold {fold_idx} | global_max_smiles_len={global_max_smiles_len}")
        print(banner)

        for split in ("train", "val", "test"):
            process_csv_to_pt(
                csv_path=os.path.join(source_dir, f"{dataset_name}_{split}.csv"),
                results_dir=target_dir,
                dataset_name=f"{dataset_name}_{split}",
                **shared_kwargs,
            )

        print(f"\nFold {fold_idx} completed! Saved to: {target_dir}")


In [6]:
# Driver cell: builds per-fold train/val/test .pt files for the configured
# dataset/task, using ONE max SMILES length shared by every fold and split.
# NOTE(review): the absolute /home/... paths below are machine-specific; adjust them before running elsewhere.
dataset_name = "selectivity"
task_name = "Ki"

# (1) Path to the original (pre-split) CSV: point this at the single file that matches the pipeline in use
raw_csv_path = f'/home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/{task_name}_selectivity.csv'

# (2) Compute the global max_len (a 95th percentile can be used instead of the absolute max)
global_max_len = infer_max_smiles_len_from_original_csv(
    original_csv_path=raw_csv_path,
    smiles_col="SMILES",
    cap=256,
    percentile=95.0,   # use None for the absolute max
)

# (3) Root directory holding the per-fold CSVs, and root directory for the results
fold_csv_root = f"/home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/{task_name}"
results_root = f"./data/{dataset_name}/{task_name}"

# (4) Run generation.
# The vocab pickle path below is the same directory as results_root.
build_fold_pt_with_global_maxlen(
    fold_csv_dir_root=fold_csv_root,
    results_dir_root=results_root,
    dataset_name=dataset_name,
    task_name=task_name,
    vocab_path=f"./data/{dataset_name}/{task_name}/smiles_char_dict.pkl",
    global_max_smiles_len=global_max_len,
    input_col="SMILES",
    target_col="Ssel",
    ecfp_bits=2048,
    start_fold=1,
    end_fold=5,
)


Processing Fold 1 | global_max_smiles_len=69

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold1/selectivity_train.csv...
[INFO] Using provided max_smiles_len: 69
[INFO] valid_count=1313 / total=1313

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold1/selectivity_val.csv...
[INFO] Using provided max_smiles_len: 69
[INFO] valid_count=146 / total=146

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold1/selectivity_test.csv...
[INFO] Using provided max_smiles_len: 69
[INFO] valid_count=365 / total=365

Fold 1 completed! Saved to: ./data/selectivity/Ki/fold1

Processing Fold 2 | global_max_smiles_len=69

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold2/selectivity_train.csv...
[INFO] Using provided max_smiles_len: 69
[INFO] valid_count=1313 / total=1313

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold2/selectivity_val.csv...
[IN