In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx
from mmfdl.util.data_gen_modify import make_variable_one
from mmfdl.util.utils import formDataset_Single

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

In [None]:
def atom_features(atom):
    """Build the feature vector for a single atom as a NumPy array.

    The vector concatenates, in this order:
      * element symbol encoding (symbols not listed fall back to 'Unknown'),
      * degree encoding over 0..10 (``one_of_k_encoding`` raises for any other value),
      * total hydrogen count encoding over 0..10 (other values fall back to 10),
      * implicit valence encoding over 0..10 (other values fall back to 10),
      * the aromaticity flag.

    ``atom`` must provide GetSymbol, GetDegree, GetTotalNumHs,
    GetImplicitValence and GetIsAromatic.
    """
    known_symbols = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
        'Cr', 'Pt', 'Hg', 'Pb', 'Unknown',
    ]
    # Buckets 0..10, shared by the degree, hydrogen-count and valence encodings.
    count_buckets = list(range(11))

    symbol_part = one_of_k_encoding_unk(atom.GetSymbol(), known_symbols)
    degree_part = one_of_k_encoding(atom.GetDegree(), count_buckets)
    hydrogen_part = one_of_k_encoding_unk(atom.GetTotalNumHs(), count_buckets)
    valence_part = one_of_k_encoding_unk(atom.GetImplicitValence(), count_buckets)
    aromatic_part = [atom.GetIsAromatic()]

    return np.array(symbol_part + degree_part + hydrogen_part + valence_part + aromatic_part)

def one_of_k_encoding(x, allowable_set):
    """Return one boolean per entry of ``allowable_set``: True where the entry equals ``x``.

    Raises:
        Exception: if ``x`` is not contained in ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception(f"input {x} not in allowable set{allowable_set}:")

def one_of_k_encoding_unk(x, allowable_set):
    """Encode ``x`` against ``allowable_set``, treating unlisted values as the last entry.

    Returns one boolean per entry of ``allowable_set``: True where the entry
    equals ``x``, or equals the last entry when ``x`` is not in the set.
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

def smile_to_graph(smile):
    """Turn a SMILES string into ``(atom_count, node_features, edge_index)``.

    Returns ``(None, None, None)`` when RDKit cannot parse the string.
    Otherwise:
      * ``atom_count`` is the number of atoms in the parsed molecule,
      * ``node_features`` is a list with one array per atom, each being the
        ``atom_features`` vector divided by the sum of its entries,
      * ``edge_index`` is a list of ``[source, target]`` atom-index pairs taken
        from the directed version of the bond graph, so every bond appears in
        both directions. It is empty for a molecule without bonds.
    """
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:
        return None, None, None

    atom_count = molecule.GetNumAtoms()

    # Scale every atom's feature vector so that its entries sum to 1.
    node_features = [
        raw / sum(raw)
        for raw in map(atom_features, molecule.GetAtoms())
    ]

    bond_pairs = [
        [bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
        for bond in molecule.GetBonds()
    ]
    directed_graph = nx.Graph(bond_pairs).to_directed()
    edge_index = [[source, target] for source, target in directed_graph.edges]

    return atom_count, node_features, edge_index

def process_csv_to_pt(csv_path, vocab_path, results_dir, dataset_name,
                      input_col='SMILES', target_col='Ssel', max_smiles_len=44, ecfp_bits=2048):
    """
    Build a dataset (encoded SMILES, ECFP fingerprint, molecular graph, label) from a CSV file.

    A row is skipped when its SMILES or label is missing, when RDKit cannot
    parse the SMILES, when ``smile_to_graph`` yields no edges, or when any
    step raises (the error is printed and processing continues).

    Parameters:
    -----------
    csv_path : str
        CSV file to read; must contain `input_col` and `target_col`.
    vocab_path : str
        Pickle file holding the vocabulary handed to `make_variable_one`.
    results_dir : str
        Forwarded to `formDataset_Single` as `root`.
    dataset_name : str
        Forwarded to `formDataset_Single` as `dataset`.
    input_col : str
        Name of the column holding SMILES strings.
    target_col : str
        Name of the column holding the labels (converted with `float`).
    max_smiles_len : int
        Forwarded to `make_variable_one` (presumably the encoded sequence
        length -- confirm against that helper).
    ecfp_bits : int
        Number of bits of the radius-2 Morgan fingerprint.

    Returns:
    --------
    The object returned by `formDataset_Single`.

    Raises:
    -------
    ValueError
        If `input_col` or `target_col` is not a column of the CSV.
    """
    print(f'\nProcessing {csv_path}...')

    df = pd.read_csv(csv_path)
    if input_col not in df.columns:
        raise ValueError(f"Input column '{input_col}' not found in CSV. Available columns: {df.columns.tolist()}")
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in CSV. Available columns: {df.columns.tolist()}")

    # Drop incomplete rows jointly so that every SMILES stays paired with the
    # label from its own row. Dropping NaNs per column and truncating to the
    # shorter list would shift the pairing as soon as one value is missing.
    complete_rows = df.dropna(subset=[input_col, target_col])
    smiles_list = complete_rows[input_col].tolist()
    labels = complete_rows[target_col].tolist()

    print(f'  Found {len(smiles_list)} samples')

    # NOTE: pickle.load can execute arbitrary code; only use vocabulary files
    # from a trusted source.
    with open(vocab_path, 'rb') as f:
        smilesVoc = pickle.load(f)

    # Pattern for explicit hydrogen atoms; built once instead of once per row.
    explicit_h_pattern = Chem.MolFromSmarts("[H]")

    encoded_smi_list = []
    ecfp_list = []
    labels_list = []
    smile_graph_dict = {}

    valid_count = 0
    for idx, (smi, label) in enumerate(zip(smiles_list, labels)):
        try:
            # SMILES encoding
            encoded_smi = make_variable_one(smi, smilesVoc, max_smiles_len)

            # Generate ECFP
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue
            if mol.HasSubstructMatch(explicit_h_pattern):
                mol = Chem.RemoveHs(mol)
            ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=ecfp_bits)
            ecfp_array = np.array(ecfp, dtype=np.float32)

            # Generate Graph
            c_size, features, edge_index = smile_to_graph(smi)
            if edge_index == [] or features is None:
                continue

            # Convert the label before anything is appended: a failure here
            # must not leave the output lists with different lengths.
            label_value = float(label)

        except Exception as e:
            print(f'  Error processing SMILES {idx}: {smi}, error: {e}')
            continue

        # Everything for this sample succeeded; record all parts together.
        encoded_smi_list.append(encoded_smi)
        ecfp_list.append(ecfp_array.tolist())
        labels_list.append(label_value)
        smile_graph_dict[valid_count] = (c_size, features, edge_index)
        valid_count += 1

    # True when no sample was skipped.
    same = len(smiles_list) == valid_count
    print(same)

    dataset = formDataset_Single(
        root=results_dir,
        dataset=dataset_name,
        encodedSmi=encoded_smi_list,
        ecfp=ecfp_list,
        y=labels_list,
        smile_graph=smile_graph_dict
    )

    return dataset

### Load Vocabulary

In [3]:
# Dataset and task identifiers; both are interpolated into the data paths below
# and in the fold-processing cell.
dataset_name = 'selectivity'
task_name = 'Ki'
# Location of the pickled SMILES vocabulary (presumably a character dictionary,
# judging by the file name -- confirm against the code that wrote it).
vocab_path = f'./data/{dataset_name}/{task_name}/smiles_char_dict.pkl'

# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(vocab_path, 'rb') as f:
    smilesVoc = pickle.load(f)

In [4]:
# Inclusive range of cross-validation folds to process.
start_fold_num = 1
end_fold_num = 5

# Directory holding the per-fold input CSVs. The prefix is a machine-specific
# absolute path: change it here when running on another machine. The dataset
# and task parts come from the same variables as the output directory, so
# inputs and outputs cannot drift apart when either one is changed.
input_root = f'/home/rlawlsgurjh/hdd/work/ChEMBLv2/data/{dataset_name}_processed/{task_name}'

# Splits processed for every fold, in this order.
splits = ('train', 'val', 'test')

for fold_num in range(start_fold_num, end_fold_num + 1):
    input_dir = f'{input_root}/fold{fold_num}'
    results_dir = f'./data/{dataset_name}/{task_name}/fold{fold_num}'
    os.makedirs(results_dir, exist_ok=True)

    # Input CSV path per split.
    csv_paths = {split: f'{input_dir}/{dataset_name}_{split}.csv' for split in splits}
    # Result file path per split. These are not passed to process_csv_to_pt;
    # where the files are actually written is decided by formDataset_Single.
    pt_paths = {split: f'{results_dir}/{dataset_name}_{split}.pt' for split in splits}

    # Generate Train, Validation, Test
    print('\n' + '=' * 60)
    print(f'Processing Fold {fold_num}')
    print('=' * 60)

    fold_datasets = {}
    for split in splits:
        fold_datasets[split] = process_csv_to_pt(
            csv_path=csv_paths[split],
            vocab_path=vocab_path,
            results_dir=results_dir,
            dataset_name=f'{dataset_name}_{split}',
            input_col='SMILES',
            target_col='Ssel',
            ecfp_bits=2048
        )

    # Per-split names stay bound (holding the latest fold's values) for any
    # later cell that reads them.
    train_csv, val_csv, test_csv = csv_paths['train'], csv_paths['val'], csv_paths['test']
    train_pt, val_pt, test_pt = pt_paths['train'], pt_paths['val'], pt_paths['test']
    train_dataset = fold_datasets['train']
    val_dataset = fold_datasets['val']
    test_dataset = fold_datasets['test']

    print(f'\nFold {fold_num} completed!')
    print(f'Results saved to: {results_dir}')


Processing Fold 1

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold1/selectivity_train.csv...
  Found 1313 samples
True

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold1/selectivity_val.csv...
  Found 146 samples
True

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold1/selectivity_test.csv...
  Found 365 samples
True

Fold 1 completed!
Results saved to: ./data/selectivity/Ki/fold1

Processing Fold 2

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold2/selectivity_train.csv...
  Found 1313 samples
True

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold2/selectivity_val.csv...
  Found 146 samples
True

Processing /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Ki/fold2/selectivity_test.csv...
  Found 365 samples
True

Fold 2 completed!
Results saved to: ./data/selectivity/Ki/fold2

Processing Fold 3

Processing /