In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx
from torch_geometric.loader import DataLoader
from torch_geometric import data as DATA
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr


import os

# When the kernel was started inside a directory named `notebook`, move up one
# level so that the relative paths built later (`data/...`, `results/...`)
# resolve against the parent directory. The basename guard makes this cell
# safe to re-run: after the first chdir the condition is no longer true, so a
# second execution does not climb another level.
if os.path.basename(os.getcwd()) == 'notebook':
    os.chdir('..')

# Echo the effective working directory so the reader can verify it.
print(f"path: {os.getcwd()}")

from mmfdl.util.data_gen_modify import make_variable_one
from mmfdl.util.utils_smiecfp import getInput_mask
from mmfdl.model.model_combination import comModel
from mmfdl.util.utils import formDataset_Single
from mmfdl.util.normalization import LabelNormalizer

import warnings
warnings.filterwarnings('ignore')

path: /HDD1/hwan1155/work/LT-MMFE/mmfdl


In [2]:
# ---------------------------------------------------------------------------
# Run configuration. Everything a reader may want to tune lives in this cell.
# ---------------------------------------------------------------------------
dataset_name = 'selectivity'
task_name = 'Kd'
# Folds are evaluated over the inclusive range [start_fold, end_fold].
start_fold = 1
end_fold = 5
# Length of the ECFP (Morgan) bit vector generated per molecule.
ecfp_bits = 2048
# Length passed to `make_variable_one` when encoding SMILES strings.
max_smiles_len = 44
# Batch size of the DataLoader used for inference.
batch_size = 256

# External CSV, resolved relative to the current working directory.
external_csv_path = './../chembl/data/selectivity_processed/davis_selectivity.csv'
input_col = 'SMILES'
####################
# Candidate label columns; exactly one of them is evaluated per run.
target_cols=[
    'S(10uM)', 'S(1uM)', 'S(100nM)',
    'PI', 'Ssel', 'WS(alpha=2)', 'WS(alpha=1)',
    'RS(k=3)', 'RS(k=2)']

# Index 5 selects 'WS(alpha=2)'; change the index to evaluate another column.
target_col = target_cols[5]
#########################

vocab_path = os.path.join('data', dataset_name, task_name, 'smiles_char_dict.pkl')
# NOTE(review): this path nests <target_col>/<task_name>, whereas the fold
# checkpoints are later read from results/SGD/<dataset_name>/<task_name>/<target_col>/fold<k>.
# Confirm the differing order is intentional.
output_dir = os.path.join('results', 'SGD', dataset_name, target_col, task_name, 'external_test')
os.makedirs(output_dir, exist_ok=True)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print(f'External dataset: {external_csv_path}')
print(f'Output directory: {output_dir}')

Using device: cuda:0
External dataset: ./../chembl/data/selectivity_processed/davis_selectivity.csv
Output directory: results/SGD/selectivity/WS(alpha=2)/Kd/external_test


In [None]:
# Hyper-parameters handed to `comModel(argsCom)` when each fold's checkpoint is
# restored; they presumably have to match the values used at training time,
# otherwise `load_state_dict` will reject the checkpoint -- TODO confirm.
# NOTE(review): 44 and 2048 duplicate the literal values of `max_smiles_len`
# and `ecfp_bits` from the configuration cell; verify whether they are meant to
# be the same quantities before tying them together.
# 78 equals the length of the vector produced by `atom_features` below
# (44 symbols + 11 degrees + 11 H counts + 11 implicit valences + 1 aromatic flag).
argsCom = {
    'num_features_smi': 44,
    'num_features_ecfp': 2048,
    'num_features_x': 78,
    'dropout': 0.1, 
    'num_layer': 2,
    'num_heads': 2,
    'hidden_dim': 256,
    'output_dim': 128,
    'n_output': 1
}

# Load Vocabulary
# SECURITY: `pickle.load` executes arbitrary code embedded in the file; only
# point `vocab_path` at a vocabulary file you produced or otherwise trust.
with open(vocab_path, 'rb') as f:
    smilesVoc = pickle.load(f)
print(f'Vocabulary loaded: {len(smilesVoc)} characters')

# Helper functions
def one_of_k_encoding(x, allowable_set):
    """One-hot encode `x` against `allowable_set`.

    Returns a list of booleans, one per entry of `allowable_set`, that is True
    where the entry equals `x`.

    Raises:
        Exception: if `x` is not a member of `allowable_set`.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception(f"input {x} not in allowable set{allowable_set}:")

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode `x`, mapping values outside `allowable_set` to its last element.

    Returns a list of booleans, one per entry of `allowable_set`.
    """
    effective = x if x in allowable_set else allowable_set[-1]
    return [effective == candidate for candidate in allowable_set]

def atom_features(atom):
    """Build the per-atom feature vector as a NumPy array.

    The vector concatenates, in this order: a one-hot of the element symbol
    (44 entries, unknown symbols map to 'Unknown'), a one-hot of the degree
    (0-10, raises if outside that range), a one-hot of the total hydrogen count
    (0-10, out-of-range maps to 10), a one-hot of the implicit valence
    (0-10, out-of-range maps to 10) and the aromaticity flag -- 78 entries total.
    """
    symbols = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
        'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co',
        'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn',
        'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown',
    ]
    zero_to_ten = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    encoded = []
    encoded.extend(one_of_k_encoding_unk(atom.GetSymbol(), symbols))
    # Degree uses the strict encoder: a value outside 0-10 raises.
    encoded.extend(one_of_k_encoding(atom.GetDegree(), zero_to_ten))
    encoded.extend(one_of_k_encoding_unk(atom.GetTotalNumHs(), zero_to_ten))
    encoded.extend(one_of_k_encoding_unk(atom.GetImplicitValence(), zero_to_ten))
    encoded.append(atom.GetIsAromatic())
    return np.array(encoded)

def smile_to_graph(smile):
    """Convert a SMILES string into graph components.

    Returns:
        A tuple `(c_size, features, edge_index)` where `c_size` is the atom
        count, `features` is a list with one sum-normalised feature vector per
        atom, and `edge_index` is a list of `[source, target]` pairs containing
        every bond in both directions. If RDKit cannot parse the SMILES the
        result is `(None, None, [])`.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None, None, []

    c_size = mol.GetNumAtoms()

    # Each atom vector is divided by the sum of its own entries.
    features = [vec / sum(vec) for vec in map(atom_features, mol.GetAtoms())]

    # Build an undirected graph from the bonds, then expand it so that every
    # bond appears once per direction.
    bond_pairs = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()]
    directed = nx.Graph(bond_pairs).to_directed()
    edge_index = [[src, dst] for src, dst in directed.edges]

    return c_size, features, edge_index

In [None]:
def process_external_csv_to_pt(csv_path, vocab_path, results_dir, dataset_name, 
                               input_col='SMILES', target_col='Ssel', 
                               max_smiles_len=44, ecfp_bits=2048):
    """Featurise an external CSV and wrap it with `formDataset_Single`.

    For every row that has both a SMILES string and a label, the function
    builds three views of the molecule -- the encoded SMILES sequence, a Morgan
    (radius 2) bit-vector fingerprint and a graph -- and hands all of them to
    `formDataset_Single`.

    Args:
        csv_path: Path of the CSV file to read.
        vocab_path: Path of the pickled SMILES vocabulary.
        results_dir: Passed to `formDataset_Single` as `root`.
        dataset_name: Passed to `formDataset_Single` as `dataset`.
        input_col: Name of the column holding SMILES strings.
        target_col: Name of the column holding the regression label.
        max_smiles_len: Length argument forwarded to `make_variable_one`.
        ecfp_bits: Number of bits of the Morgan fingerprint.

    Returns:
        The object returned by `formDataset_Single`.

    Raises:
        ValueError: if `input_col` or `target_col` is missing from the CSV.
    """
    print(f'\nProcessing external dataset: {csv_path}...')

    df = pd.read_csv(csv_path)
    if input_col not in df.columns:
        raise ValueError(f"Input column '{input_col}' not found in CSV. Available columns: {df.columns.tolist()}")
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in CSV. Available columns: {df.columns.tolist()}")

    # Drop incomplete rows *jointly*. Dropping NaNs from each column on its own
    # and truncating to the shorter list shifts one column against the other,
    # silently pairing molecules with labels that belong to different rows.
    paired = df.dropna(subset=[input_col, target_col])
    n_incomplete = len(df) - len(paired)
    if n_incomplete:
        print(f'  Dropped {n_incomplete} rows with a missing SMILES or label')
    smiles_list = paired[input_col].tolist()
    labels = paired[target_col].tolist()

    print(f'  Found {len(smiles_list)} samples')

    # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
    with open(vocab_path, 'rb') as f:
        smilesVoc = pickle.load(f)

    # Compiled once instead of once per molecule.
    explicit_h_pattern = Chem.MolFromSmarts("[H]")

    encoded_smi_list = []
    ecfp_list = []
    labels_list = []
    smile_graph_list = []
    skipped_count = 0

    # `row_idx` is the row label in the original CSV frame, so messages point
    # at the offending input row.
    for row_idx, smi, label in zip(paired.index, smiles_list, labels):
        try:
            # Convert the label first: every value that can fail is computed
            # before any list is appended to, so the four lists stay aligned.
            label_value = float(label)

            # SMILES encoding
            encoded_smi = make_variable_one(smi, smilesVoc, max_smiles_len)

            # Generate ECFP
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                skipped_count += 1
                continue
            if mol.HasSubstructMatch(explicit_h_pattern):
                mol = Chem.RemoveHs(mol)
            ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=ecfp_bits)
            ecfp_array = np.array(ecfp, dtype=np.float32)

            # Generate Graph; molecules without any bond are not usable.
            c_size, features, edge_index = smile_to_graph(smi)
            if edge_index == [] or features is None:
                skipped_count += 1
                continue
        except Exception as e:
            # Best effort: report the row and carry on with the remaining ones.
            print(f'  Error processing SMILES {row_idx}: {smi}, error: {e}')
            skipped_count += 1
            continue

        encoded_smi_list.append(encoded_smi)
        ecfp_list.append(ecfp_array.tolist())
        labels_list.append(label_value)
        smile_graph_list.append((c_size, features, edge_index))

    valid_count = len(labels_list)
    print(f'  Successfully processed {valid_count} samples')
    if skipped_count:
        print(f'  Skipped {skipped_count} samples (unparsable SMILES, no bonds, or processing error)')

    # Graphs are keyed by their position in the other three lists.
    smile_graph_dict = dict(enumerate(smile_graph_list))

    dataset = formDataset_Single(
        root=results_dir,
        dataset=dataset_name,
        encodedSmi=encoded_smi_list,
        ecfp=ecfp_list,
        y=labels_list,
        smile_graph=smile_graph_dict
    )

    return dataset

# Featurise the external CSV once; the resulting dataset is shared by all folds.
external_pt_dir = os.path.join(output_dir, 'external_data')
os.makedirs(external_pt_dir, exist_ok=True)

external_dataset = process_external_csv_to_pt(
    external_csv_path,
    vocab_path,
    external_pt_dir,
    'external_test',
    input_col=input_col,
    target_col=target_col,
    max_smiles_len=max_smiles_len,
    ecfp_bits=ecfp_bits,
)

print(f'External dataset saved to: {external_pt_dir}')

In [None]:
# Collects two metric rows per evaluated fold (one per inverse-transform basis);
# the summary cell aggregates them.
all_fold_metrics = []


def flattened_data(data):
    """Merge a list of per-batch arrays into a single flat 1-D NumPy array."""
    fla_data = [item for sublist in data for item in sublist]
    merged_data = np.array(fla_data).flatten()
    return merged_data


def _regression_metrics(y_true, y_pred):
    """Return `(rmse, r2, pcc)` for one pair of target / prediction arrays."""
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    pcc = pearsonr(y_true, y_pred)[0]
    return rmse, r2, pcc


# --- Fold-independent setup (previously rebuilt on every iteration) ----------
external_y_labels = np.array([sample.y.item() for sample in external_dataset])

normalizer_external = LabelNormalizer(mode='zscore')
normalizer_external.fit(external_y_labels)
print(f'External normalizer fitted: mean={normalizer_external.mean:.4f}, std={normalizer_external.std:.4f}')

# shuffle=False keeps predictions in dataset order.
external_loader = DataLoader(external_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

for fold_num in range(start_fold, end_fold + 1):
    print('\n' + '=' * 60)
    print(f'Processing Fold {fold_num}')
    print('=' * 60)

    checkpoint_dir = os.path.join('results', 'SGD', dataset_name, task_name, target_col, f'fold{fold_num}')

    # --- Locate the artefacts of this fold; skip the fold if any is missing --
    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')
    if not os.path.exists(checkpoint_path):
        print(f'Warning: Checkpoint file not found: {checkpoint_path}')
        continue

    normalizer_train_path = os.path.join(checkpoint_dir, 'normalizer.pkl')
    if not os.path.exists(normalizer_train_path):
        print(f'Warning: Normalizer file not found: {normalizer_train_path}')
        continue

    print(f'Loading checkpoint: {checkpoint_path}')
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints produced by a trusted training run.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    best_epoch = checkpoint['epoch']

    # The fusion weights are stored per epoch; the file name embeds the best epoch.
    weight_path = os.path.join(checkpoint_dir,
                               f'{dataset_name}_{task_name}_fold{fold_num}_weight_epoch_{best_epoch}.csv')
    if not os.path.exists(weight_path):
        # Same policy as for the checkpoint / normalizer: warn and skip the fold
        # instead of aborting the whole evaluation with FileNotFoundError.
        print(f'Warning: Weight file not found: {weight_path}')
        continue

    normalizer_train = LabelNormalizer.load(normalizer_train_path)
    print(f'Train normalizer loaded: mean={normalizer_train.mean:.4f}, std={normalizer_train.std:.4f}')

    # Load Model & Initialize Model
    model = comModel(argsCom).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f'Model loaded from epoch {best_epoch}, val_loss: {checkpoint["val_loss"]:.4f}')

    # Load Weight: keys 1, 2 and 3 hold the weight applied to the first,
    # second and third model output respectively.
    weight_df = pd.read_csv(weight_path)
    weight_dict = dict(zip(weight_df['Key'], weight_df['Value']))
    numpy_weights = np.array([weight_dict[1], weight_dict[2], weight_dict[3]])
    print(f'Loaded weights: {numpy_weights}')

    # --- Inference ------------------------------------------------------------
    all_targets = []
    pred_data1 = []
    pred_data2 = []
    pred_data3 = []

    with torch.no_grad():
        for batch_data in external_loader:
            encodedSmi = torch.LongTensor(batch_data.smi).to(device)
            encodedSmi_mask = torch.LongTensor(getInput_mask(batch_data.smi)).to(device)
            ecfp = torch.FloatTensor(batch_data.ep).to(device)
            x = batch_data.x.to(device)
            edge_index = batch_data.edge_index.to(device)
            batch = batch_data.batch.to(device)

            # Targets are only needed on the host, in the train-normalised
            # space the model was fitted on; float32 matches the previous
            # FloatTensor round-trip.
            y_norm = np.asarray(normalizer_train.transform(batch_data.y.cpu().numpy()), dtype=np.float32)

            y_pred = model(encodedSmi, encodedSmi_mask, ecfp, x, edge_index, batch)

            all_targets.append(y_norm)
            pred_data1.append(y_pred[0].cpu().numpy())
            pred_data2.append(y_pred[1].cpu().numpy())
            pred_data3.append(y_pred[2].cpu().numpy())

    y_true_norm = flattened_data(all_targets)
    # Final prediction = weighted sum of the three model outputs.
    y_pred_norm = (numpy_weights[0] * flattened_data(pred_data1)
                   + numpy_weights[1] * flattened_data(pred_data2)
                   + numpy_weights[2] * flattened_data(pred_data3))

    # --- Back to label space, once per normalisation basis --------------------
    y_true_train = normalizer_train.inverse_transform(y_true_norm)
    y_pred_train = normalizer_train.inverse_transform(y_pred_norm)

    # NOTE(review): `y_true_norm` was produced with the *train* normalizer, so
    # inverting it with the external normalizer does not give back the raw
    # external labels; both arrays are merely rescaled with external statistics.
    # Confirm that this is the intended definition of the 'external' basis.
    y_true_external = normalizer_external.inverse_transform(y_true_norm)
    y_pred_external = normalizer_external.inverse_transform(y_pred_norm)

    rmse_train, r2_train, pcc_train = _regression_metrics(y_true_train, y_pred_train)
    rmse_external, r2_external, pcc_external = _regression_metrics(y_true_external, y_pred_external)

    fold_metrics_train = {
        'fold': fold_num,
        'inverse_basis': 'train',
        'rmse': rmse_train,
        'r2': r2_train,
        'pcc': pcc_train
    }
    fold_metrics_external = {
        'fold': fold_num,
        'inverse_basis': 'external',
        'rmse': rmse_external,
        'r2': r2_external,
        'pcc': pcc_external
    }
    all_fold_metrics.append(fold_metrics_train)
    all_fold_metrics.append(fold_metrics_external)

    print(f'Fold {fold_num} Metrics (train inverse):')
    print(f'  RMSE: {rmse_train:.4f}')
    print(f'  R2: {r2_train:.4f}')
    print(f'  PCC: {pcc_train:.4f}')
    print(f'Fold {fold_num} Metrics (external inverse):')
    print(f'  RMSE: {rmse_external:.4f}')
    print(f'  R2: {r2_external:.4f}')
    print(f'  PCC: {pcc_external:.4f}')

    # --- Persist per-fold artefacts -------------------------------------------
    fold_output_dir = os.path.join(output_dir, f'fold{fold_num}')
    os.makedirs(fold_output_dir, exist_ok=True)

    # Only the train-basis values are written to predictions.csv.
    predictions_df = pd.DataFrame({
        'y_true': y_true_train,
        'y_pred': y_pred_train
    })
    predictions_path = os.path.join(fold_output_dir, 'predictions.csv')
    predictions_df.to_csv(predictions_path, index=False)
    print(f'Predictions saved: {predictions_path}')

    metric_df = pd.DataFrame([fold_metrics_train, fold_metrics_external])
    metric_path = os.path.join(fold_output_dir, 'metric.csv')
    metric_df.to_csv(metric_path, index=False)
    print(f'Metrics saved: {metric_path}')

    print('=' * 60)

In [None]:
# Aggregate the per-fold rows into mean / std per inverse-transform basis and
# write both the summary and the raw per-fold table to `output_dir`.
if not all_fold_metrics:
    print('No metrics to save!')
else:
    all_metrics_df = pd.DataFrame(all_fold_metrics)

    summary_rows = []
    for basis in all_metrics_df['inverse_basis'].unique():
        basis_rows = all_metrics_df.loc[all_metrics_df['inverse_basis'] == basis]
        row = {'model_name': 'MMFDL', 'inverse_basis': basis}
        # Column order: metric mean immediately followed by its std.
        for metric in ('rmse', 'r2', 'pcc'):
            row[metric] = basis_rows[metric].mean()
            row[f'{metric}_std'] = basis_rows[metric].std()
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)

    summary_path = os.path.join(output_dir, 'all_metric.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f'\nSummary metrics saved: {summary_path}')
    print('\nSummary Metrics (External Test Set):')
    print(summary_df.to_string(index=False))

    all_metrics_path = os.path.join(output_dir, 'all_folds_metrics.csv')
    all_metrics_df.to_csv(all_metrics_path, index=False)
    print(f'\nAll folds metrics saved: {all_metrics_path}')

    banner = '=' * 60
    print('\n' + banner)
    print('External test completed!')
    print(banner)